Connector MVP now working
This commit is contained in:
Parent: 7f9ad3173c
Commit: ea376c4f24
@@ -27,6 +27,7 @@ import com.microsoft.azure.sqldb.spark.bulkcopy.BulkCopyMetadata
import com.microsoft.azure.sqldb.spark.connect._
import org.apache.spark.sql.catalyst.InternalRow
import com.microsoft.azure.sqldb.spark.connect.ConnectionUtils._
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.Attribute

import scala.collection.mutable.ListBuffer
@@ -40,7 +41,7 @@ private[spark] object SQLWriter extends Logging {
var subset = false //This variable identifies whether you want to write to all columns of the SQL table or just a select few.
var DRIVER_NAME: String = "com.microsoft.sqlserver.jdbc.SQLServerDriver"
var connection: Connection = _
var ps: PreparedStatement = null
//var ps: PreparedStatement = null

override def toString: String = "SQLWriter"
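Note: for context on the DRIVER_NAME and connection fields above, this is roughly how a SQL Server JDBC connection is opened with the plain JDBC API; the server name, database and credentials below are placeholders, not values from this repository.

import java.sql.{Connection, DriverManager}

object JdbcConnectionSketch {
  // Placeholder connection details; substitute a real server, database, user and password.
  val url = "jdbc:sqlserver://myserver.database.windows.net:1433;databaseName=mydb"
  val user = "myuser"
  val password = "mypassword"

  // Mirrors the DRIVER_NAME field above; explicit registration is optional with JDBC 4+ drivers.
  def open(): Connection = {
    Class.forName("com.microsoft.sqlserver.jdbc.SQLServerDriver")
    DriverManager.getConnection(url, user, password)
  }
}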
@@ -72,20 +73,27 @@ private[spark] object SQLWriter extends Logging {
if(connection.isClosed){
val test = 1
}
// connection.setCatalog(db)


//Using Bulk Copy
val bulkCopyConfig = Config(Map(
//TODO: Using Bulk Copy
/*val bulkCopyConfig = Config(Map(
"url" -> url,
"databaseName" -> writeConfig.get[String](SqlDBConfig.DatabaseName).get,
"user" -> writeConfig.get[String](SqlDBConfig.User).get,
"password" -> writeConfig.get[String](SqlDBConfig.DatabaseName).get,
"dbTable" -> table
))

try{
// data.bulkCopyToSqlDB()
data.bulkCopyToSqlDB()
} catch {
case e: Exception =>
log.error("Error writing batch data to SQL DB. Error details: "+ e)
throw e
}
*/

/* Getting Column information */
try{

val schema = queryExecution.analyzed.output
var schemaDatatype = new ListBuffer[String]()
var colNames = ""
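Note: the hunk above disables the bulk-copy path and instead collects the analyzed schema so a parameterized INSERT can be assembled column by column. A minimal sketch of the same idea against an ordinary DataFrame, with buildInsertSql and the table name as illustrative stand-ins:

import org.apache.spark.sql.DataFrame

// Builds "INSERT INTO <table> (c1, c2, ...) VALUES (?, ?, ...)" from a DataFrame schema
// and returns it together with the column data-type names (e.g. "IntegerType", "StringType").
def buildInsertSql(df: DataFrame, table: String): (String, Seq[String]) = {
  val fields     = df.schema.fields
  val colNames   = fields.map(_.name).mkString(", ")
  val valueSlots = fields.map(_ => "?").mkString(", ")
  (s"INSERT INTO $table ($colNames) VALUES ($valueSlots);", fields.map(_.dataType.toString).toSeq)
}

The connector derives the same information from queryExecution.analyzed.output, whose attributes expose name and dataType.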
@@ -99,46 +107,45 @@ private[spark] object SQLWriter extends Logging {
})
}

//TODO: Option to overwrite
var sql = "INSERT INTO " + table + " (" + colNames.substring(0, colNames.length-1) + ")" + " VALUES (" + values.substring(0, values.length-1) + ");"
var sql = "INSERT INTO " + table + " (" + colNames.substring(0, colNames.length-1) + " , partitionid" + ")" + " VALUES (" + values.substring(0, values.length-1) + ",?" + ");"

ps = connection.prepareStatement(sql)

queryExecution.toRdd.foreachPartition(iter => {
iter.foreach(row => {
if(ps != null) {
try{
var i = 0
for(e <- schemaDatatype) {
val testVar = row.getString(i)
e match {
case "ByteType" => ps.setByte(i+1, row.getByte(i))
case "ShortType" => ps.setShort(i+1,row.getShort(i))
case "IntegerType" => ps.setInt(i+1, row.getInt(i))
case "LongType" => ps.setLong(i+1, row.getLong(i))
case "FloatType" => ps.setFloat(i+1, row.getFloat(i))
case "DoubleType" => ps.setDouble(i+1, row.getDouble(i))
// case "DecimalType" => statement.setBigDecimal(i, ) //TODO: try to use getAccessor and find a similar method in statement
case "StringType" => ps.setString(i+1, row.getString(i))
case "BinaryType" => ps.setBytes(i+1, row.getBinary(i))
case "BooleanType" => ps.setBoolean(i+1, row.getBoolean(i))
// case "TimestamType" => statement.setTimestamp(i+1, row.get.getTimestamp(i))
// case "DateType" => statement.setDate(i+1, row.getDate(i))
}
i += 1
queryExecution.toRdd.foreachPartition(iter => {
val ps = connection.prepareStatement(sql)
iter.foreach(row => {
val pid = TaskContext.get().partitionId()
if(ps != null) {
try{
var i = 0
for(e <- schemaDatatype) {
val testVar = row.getString(i)
println("Value: " + testVar + " ; i: " + i)
e match {
case "ByteType" => ps.setByte(i+1, row.getByte(i))
case "ShortType" => ps.setShort(i+1,row.getShort(i))
case "IntegerType" => ps.setInt(i+1, row.getInt(i))
case "LongType" => ps.setLong(i+1, row.getLong(i))
case "FloatType" => ps.setFloat(i+1, row.getFloat(i))
case "DoubleType" => ps.setDouble(i+1, row.getDouble(i))
// case "DecimalType" => statement.setBigDecimal(i, ) //TODO: try to use getAccessor and find a similar method in statement
case "StringType" => ps.setString(i+1, row.getString(i))
case "BinaryType" => ps.setBytes(i+1, row.getBinary(i))
case "BooleanType" => ps.setBoolean(i+1, row.getBoolean(i))
// case "TimestamType" => statement.setTimestamp(i+1, row.get.getTimestamp(i))
// case "DateType" => statement.setDate(i+1, row.getDate(i))
}
ps.execute()
} catch {
case e: SQLException => log.error("Error writing to SQL DB on row: " + row.toString())
throw e //TODO: Give users the option to abort or continue

i += 1
}
ps.setInt(2, pid)
ps.execute()
} catch {
case e: SQLException => log.error("Error writing to SQL DB on row: " + row.toString())
throw e //TODO: Give users the option to abort or continue
}


//streamToDB(sql, row, schemaDatatype)

})
}
//streamToDB(sql, row, schemaDatatype)
})
})
} catch {
case e: Exception =>
log.error("Error writing batch data to SQL DB. Error details: "+ e)
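Note: the rewritten loop above prepares one statement per partition, tags each row with TaskContext.get().partitionId(), and dispatches on the data-type name to pick the matching PreparedStatement setter. A condensed, self-contained sketch of that pattern, assuming the INSERT string's last parameter is the extra partitionid column and using placeholder JDBC credentials:

import java.sql.{DriverManager, SQLException}
import org.apache.spark.TaskContext
import org.apache.spark.sql.DataFrame

// Writes each partition through its own JDBC connection and PreparedStatement,
// appending the Spark partition id as the statement's final parameter.
def writePerPartition(df: DataFrame, insertSql: String,
                      url: String, user: String, password: String): Unit = {
  val dataTypes = df.schema.fields.map(_.dataType.toString)
  df.rdd.foreachPartition { rows =>
    val connection = DriverManager.getConnection(url, user, password)
    val ps  = connection.prepareStatement(insertSql)
    val pid = TaskContext.get().partitionId()
    rows.foreach { row =>
      try {
        dataTypes.zipWithIndex.foreach { case (dt, i) =>
          dt match {
            case "IntegerType" => ps.setInt(i + 1, row.getInt(i))
            case "LongType"    => ps.setLong(i + 1, row.getLong(i))
            case "DoubleType"  => ps.setDouble(i + 1, row.getDouble(i))
            case "BooleanType" => ps.setBoolean(i + 1, row.getBoolean(i))
            case _             => ps.setString(i + 1, String.valueOf(row.get(i)))
          }
        }
        ps.setInt(dataTypes.length + 1, pid) // the extra partitionid column
        ps.execute()
      } catch {
        case e: SQLException => throw e      // the connector logs and rethrows here
      }
    }
    connection.close()
  }
}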
@@ -182,14 +189,14 @@ private[spark] object SQLWriter extends Logging {
This method checks to see if the table exists in SQL. If it doesn't, it creates the table. This method also ensures that the data types of the data frame are compatible with those of Azure SQL Database. If they aren't, it converts them and returns the converted data frame.
*/
def streamToDB(
// statement: PreparedStatement,
ps: PreparedStatement,
sql: String,
row: InternalRow,
colDataTypes: ListBuffer[String]
): Unit = {

if(ps == null) {
ps = connection.prepareStatement(sql)
// ps = connection.prepareStatement(sql)
}

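Note: the comment above says streamToDB should check whether the target table exists and create it when it does not. A sketch of such a check with standard JDBC metadata, where createTableSql is assumed to be generated from the DataFrame schema elsewhere:

import java.sql.Connection

// Returns true if a table with the given name is visible through JDBC metadata.
def tableExists(connection: Connection, table: String): Boolean = {
  val rs = connection.getMetaData.getTables(null, null, table, Array("TABLE"))
  try rs.next() finally rs.close()
}

// Creates the table only when it is missing; createTableSql is a hypothetical
// CREATE TABLE string built from the DataFrame schema elsewhere.
def ensureTable(connection: Connection, table: String, createTableSql: String): Unit = {
  if (!tableExists(connection, table)) {
    val stmt = connection.createStatement()
    try stmt.executeUpdate(createTableSql) finally stmt.close()
  }
}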
@@ -52,10 +52,9 @@ class SQLSinkTest extends StreamTest with SharedSQLContext {
)
var stream: DataStreamWriter[Row] = null
val input = MemoryStream[String]
input.addData("1", "2", "3", "4")
input.addData("1", "2", "3", "4", "5", "6", "7", "8")
var df = input.toDF().withColumnRenamed("value", "input")

//val df = createReader("/sqltestdata")
//TODO: Create Util functions to create SQL DB, table and truncate it and call the functions here
withTempDir { checkpointDir =>
config += ("checkpointLocation" -> checkpointDir.getCanonicalPath) //Adding Checkpoint Location
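Note: the test above feeds the sink from a MemoryStream and checkpoints into a temporary directory. A stripped-down sketch of that wiring outside the test harness (e.g. in spark-shell); the "sqlserver" format name and checkpoint path are placeholders, not necessarily the short name this connector registers:

import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream

val spark = SparkSession.builder().master("local[2]").appName("sink-sketch").getOrCreate()
import spark.implicits._
implicit val sqlContext: SQLContext = spark.sqlContext   // required by MemoryStream

val input = MemoryStream[String]
input.addData("1", "2", "3", "4", "5", "6", "7", "8")

val query = input.toDF().withColumnRenamed("value", "input")
  .writeStream
  .format("sqlserver")                                   // placeholder source short name
  .option("checkpointLocation", "/tmp/sql-sink-checkpoint") // the test uses withTempDir instead
  .start()

query.processAllAvailable()                              // drain everything queued so far
query.stop()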
@@ -68,18 +67,12 @@ class SQLSinkTest extends StreamTest with SharedSQLContext {

try {
failAfter(streamingTimeout){
Thread.sleep(100)
//streamStart.processAllAvailable()
streamStart.processAllAvailable()
}
//var testDF = spark.read.sqlDB(Config(config)).as[String].select("value").map(_.toInt)
checkDatasetUnorderly(spark.read.sqlDB(Config(config)).select($"input").as[String].map(_.toInt), 1, 2, 3, 4)
checkDatasetUnorderly(spark.read.sqlDB(Config(config)).select($"input").as[String].map(_.toInt), 1, 2, 3, 4, 5, 6, 7, 8)
} finally {
streamStart.stop()
}

//val df: DataFrame =

//TODO: Read data from SQL DB and check if it matches the data written (see EH implementation)
}

/*private def createWriter(inputDF: DataFrame, sqlConfig: AzureSQLConfig, withOutputMode: Option[OutputMode]): StreamingQuery = {
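Note: the assertion above reads rows back through the connector's own sqlDB reader. An equivalent verification with Spark's built-in JDBC source, using placeholder connection options for illustration (the test pulls these from its config map instead):

import org.apache.spark.sql.SparkSession

val spark = SparkSession.builder().master("local[2]").appName("verify-sketch").getOrCreate()
import spark.implicits._

val written = spark.read.format("jdbc")
  .option("url", "jdbc:sqlserver://myserver.database.windows.net:1433;databaseName=mydb")
  .option("dbtable", "myTable")
  .option("user", "myuser")
  .option("password", "mypassword")
  .load()

// Compare the written values, ignoring order, against what was fed to the stream.
val values = written.select($"input".cast("int")).as[Int].collect().sorted
assert(values.toSeq == Seq(1, 2, 3, 4, 5, 6, 7, 8))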