Merge branch 'master' into master

Commit c1a6c0e85d

pom.xml (51 changed lines)
@@ -17,6 +17,8 @@
     </licenses>
     <properties>
         <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
+        <scala.binary.version>2.11</scala.binary.version>
+        <spark.version>2.3.1</spark.version>
     </properties>
     <dependencies>
         <dependency>
@@ -24,6 +26,12 @@
             <artifactId>scala-library</artifactId>
             <version>2.11.8</version>
         </dependency>
+        <dependency>
+            <groupId>org.scala-lang</groupId>
+            <artifactId>scala-reflect</artifactId>
+            <version>${scala.version}</version>
+            <scope>compile</scope>
+        </dependency>
         <dependency>
             <groupId>junit</groupId>
             <artifactId>junit</artifactId>
@@ -38,7 +46,35 @@
         <dependency>
             <groupId>org.apache.spark</groupId>
             <artifactId>spark-sql_2.11</artifactId>
-            <version>2.2.1</version>
+            <version>${spark.version}</version>
         </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-core_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <type>test-jar</type>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-sql_2.11</artifactId>
+            <version>${spark.version}</version>
+            <type>test-jar</type>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <type>test-jar</type>
+            <scope>test</scope>
+        </dependency>
+        <dependency>
+            <groupId>org.apache.spark</groupId>
+            <artifactId>spark-catalyst_${scala.binary.version}</artifactId>
+            <version>${spark.version}</version>
+            <type>test-jar</type>
+            <scope>test</scope>
+        </dependency>
         <dependency>
             <groupId>org.scalactic</groupId>
@@ -81,6 +117,19 @@
     </distributionManagement>
     <build>
         <plugins>
+            <plugin>
+                <groupId>org.scala-tools</groupId>
+                <artifactId>maven-scala-plugin</artifactId>
+                <version>2.15.2</version>
+                <executions>
+                    <execution>
+                        <goals>
+                            <goal>compile</goal>
+                            <goal>testCompile</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
            <plugin>
                <groupId>org.codehaus.mojo</groupId>
                <artifactId>build-helper-maven-plugin</artifactId>

@@ -0,0 +1 @@
com.microsoft.azure.sqldb.spark.sql.streaming.SQLSinkProvider

@@ -53,7 +53,7 @@ private[spark] object ConnectionUtils {
   * @param url the string url without the JDBC prefix
   * @return the url with the added JDBC prefix
   */
-  def createJDBCUrl(url: String): String = SqlDBConfig.JDBCUrlPrefix + url
+  def createJDBCUrl(url: String, port: Option[String] = None): String = SqlDBConfig.JDBCUrlPrefix + url + ":" + port.getOrElse("1433").toString()

  /**
   * Gets a JDBC connection based on Config properties
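
For illustration, this is how the new signature behaves, assuming SqlDBConfig.JDBCUrlPrefix is the usual jdbc:sqlserver:// prefix; the host name is a placeholder, and the call is shown as if made from within the library's package since ConnectionUtils is private[spark]:

import com.microsoft.azure.sqldb.spark.connect.ConnectionUtils.createJDBCUrl

// No port supplied: the default 1433 is appended.
val defaultUrl = createJDBCUrl("myserver.database.windows.net")
// "jdbc:sqlserver://myserver.database.windows.net:1433"

// Explicit port, as SQLWriter does with the portNumber config value.
val customUrl = createJDBCUrl("myserver.database.windows.net", Some("58502"))
// "jdbc:sqlserver://myserver.database.windows.net:58502"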

@@ -0,0 +1,43 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.microsoft.azure.sqldb.spark.sql.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}

private[spark] class SqlSink(sqlContext: SQLContext,
                             parameters: Map[String, String])
  extends Sink
  with Logging {

  @volatile private var lastBatchId: Long = -1L

  override def toString: String = "SQLSink"

  override def addBatch(batchId: Long, data: DataFrame): Unit = {
    if (batchId <= lastBatchId) {
      log.info(s"Skipping already committed batch $batchId")
    } else {
      //val mode = if (parameters.overwrite) SaveMode.Overwrite else SaveMode.Append
      val mode = SaveMode.Append //TODOv2: Only Append mode supported for now
      SQLWriter.write(sqlContext.sparkSession, data, data.queryExecution, mode, parameters)
      lastBatchId = batchId
    }
  }
}
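
The batchId guard above is what keeps the sink idempotent when Structured Streaming re-delivers a micro-batch after a failure: any batch with an id at or below the last committed one is skipped. A minimal standalone sketch of the same pattern (not part of this commit, shown only to illustrate the replay-skip behavior):

// Generic replay-skip pattern: run the write at most once per batch id.
class IdempotentBatchGuard {
  @volatile private var lastBatchId: Long = -1L

  def runOnce(batchId: Long)(write: => Unit): Unit = {
    if (batchId <= lastBatchId) {
      println(s"Skipping already committed batch $batchId")
    } else {
      write
      lastBatchId = batchId
    }
  }
}

val guard = new IdempotentBatchGuard
guard.runOnce(5) { println("writing batch 5") } // executes
guard.runOnce(5) { println("writing batch 5") } // skipped on replay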

@@ -0,0 +1,45 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.microsoft.azure.sqldb.spark.sql.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.execution.streaming.Sink
import org.apache.spark.sql.sources.{DataSourceRegister, StreamSinkProvider}
import org.apache.spark.sql.streaming.OutputMode

import scala.util.Try

/**
 * The provider class for the [[SqlSink]].
 */
private[spark] class SQLSinkProvider extends DataSourceRegister
  with StreamSinkProvider
  with Logging {

  override def shortName(): String = "sqlserver"

  override def createSink(sqlContext: SQLContext,
                          parameters: Map[String, String],
                          partitionColumns: Seq[String],
                          outputMode: OutputMode): Sink = {

    new SqlSink(sqlContext, parameters)
  }
}
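
Together with the one-line service registration added earlier (presumably a META-INF/services entry for org.apache.spark.sql.sources.DataSourceRegister), the shortName above is what lets a streaming query select this sink with format("sqlserver"). A usage sketch modeled on the options exercised in SQLSinkTest further down; the server, database, credentials, and checkpoint path are placeholders:

import org.apache.spark.sql.SparkSession

object SqlStreamingSinkExample {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("sql-sink-example").getOrCreate()

    // Connection options; key names follow the config map used in SQLSinkTest.
    val options = Map(
      "url"                -> "myserver.database.windows.net",  // placeholder
      "databaseName"       -> "mydb",                           // placeholder
      "user"               -> "myuser",                         // placeholder
      "password"           -> "mypassword",                     // placeholder
      "dbTable"            -> "dbo.streamed_rows",              // placeholder
      "portNumber"         -> "1433",
      "checkpointLocation" -> "/tmp/sql-sink-checkpoint"        // placeholder
    )

    // Toy source; any streaming DataFrame with matching columns works.
    val query = spark.readStream
      .format("rate")
      .load()
      .selectExpr("CAST(value AS STRING) AS input")
      .writeStream
      .format("sqlserver")   // resolved to SQLSinkProvider via DataSourceRegister
      .options(options)
      .outputMode("Append")  // only Append is supported by the sink for now
      .start()

    query.awaitTermination()
  }
}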

@@ -0,0 +1,282 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package com.microsoft.azure.sqldb.spark.sql.streaming

import java.sql.{Connection, DriverManager, PreparedStatement, SQLException}

import com.microsoft.azure.sqldb.spark.config.{Config, SqlDBConfig}
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.{DataFrame, SaveMode, SparkSession}
import com.microsoft.azure.sqldb.spark.bulkcopy.BulkCopyMetadata
import com.microsoft.azure.sqldb.spark.connect._
import org.apache.spark.sql.catalyst.InternalRow
import com.microsoft.azure.sqldb.spark.connect.ConnectionUtils._
import org.apache.spark.TaskContext
import org.apache.spark.sql.catalyst.expressions.Attribute

import scala.collection.mutable.ListBuffer

/**
 * The [[SQLWriter]] class is used to write data from a batch query
 * or structured streaming query, given by a [[QueryExecution]], to Azure SQL Database or Azure SQL Data Warehouse.
 */
private[spark] object SQLWriter extends Logging {

  var subset = false //This variable identifies whether you want to write to all columns of the SQL table or just select few.
  var DRIVER_NAME: String = "com.microsoft.sqlserver.jdbc.SQLServerDriver"
  var connection: Connection = _
  //var ps: PreparedStatement = null

  override def toString: String = "SQLWriter"

  def write(
      sparkSession: SparkSession,
      data: DataFrame,
      queryExecution: QueryExecution,
      saveMode: SaveMode,
      parameters: Map[String, String]
    ): Unit = {

    //TODO: Clean this up and make it similar to that of SQL DB
    var writeConfig = Config(parameters)
    val url = writeConfig.get[String](SqlDBConfig.URL).get //TODO: If URL is not specified, try to construct one
    val db = writeConfig.get[String](SqlDBConfig.DatabaseName).get
    var createTable = false
    val table = writeConfig.get[String](SqlDBConfig.DBTable).getOrElse(
      //createTable = true
      throw new IllegalArgumentException("Table not found in DBTable in Config")
    )
    val user = writeConfig.get[String](SqlDBConfig.User).get
    val password = writeConfig.get[String](SqlDBConfig.Password).get
    val port = writeConfig.get[String](SqlDBConfig.PortNumber).getOrElse("1433")

    //Using regular write
    Class.forName(DRIVER_NAME)
    connection = DriverManager.getConnection(createJDBCUrl(url, Some(port))+";database="+db, user, password)

    //TODO: Using Bulk Copy
    /*val bulkCopyConfig = Config(Map(
      "url" -> url,
      "databaseName" -> writeConfig.get[String](SqlDBConfig.DatabaseName).get,
      "user" -> writeConfig.get[String](SqlDBConfig.User).get,
      "password" -> writeConfig.get[String](SqlDBConfig.DatabaseName).get,
      "dbTable" -> table
    ))
    try{
      data.bulkCopyToSqlDB()
    } catch {
      case e: Exception =>
        log.error("Error writing batch data to SQL DB. Error details: "+ e)
        throw e
    }
    */

    /* Getting Column information */
    try{

      val schema = queryExecution.analyzed.output
      var schemaDatatype = new ListBuffer[String]()
      var colNames = ""
      var values = ""

      if(schema.size > 0) {
        schema.foreach(col => {
          colNames += col.name + ","
          values += "?,"
          schemaDatatype += col.dataType.toString()
        })
      }

      //var sql = "INSERT INTO " + table + " (" + colNames.substring(0, colNames.length-1) + " , partitionid" + ")" + " VALUES (" + values.substring(0, values.length-1) + ",?" + ");"
      var sql = "INSERT INTO " + table + " (" + colNames.substring(0, colNames.length-1) + ")" + " VALUES (" + values.substring(0, values.length-1) + ");"

      queryExecution.toRdd.foreachPartition(iter => {
        val ps = connection.prepareStatement(sql)
        iter.foreach(row => {
          val pid = TaskContext.get().partitionId()
          if(ps != null) {
            try{
              var i = 0
              for(e <- schemaDatatype) {
                val testVar = row.getString(i)
                println("Value: " + testVar + " ; i: " + i)
                e match {
                  case "ByteType" => ps.setByte(i+1, row.getByte(i))
                  case "ShortType" => ps.setShort(i+1, row.getShort(i))
                  case "IntegerType" => ps.setInt(i+1, row.getInt(i))
                  case "LongType" => ps.setLong(i+1, row.getLong(i))
                  case "FloatType" => ps.setFloat(i+1, row.getFloat(i))
                  case "DoubleType" => ps.setDouble(i+1, row.getDouble(i))
                  // case "DecimalType" => statement.setBigDecimal(i, ) //TODO: try to use getAccessor and find a similar method in statement
                  case "StringType" => ps.setString(i+1, row.getString(i))
                  case "BinaryType" => ps.setBytes(i+1, row.getBinary(i))
                  case "BooleanType" => ps.setBoolean(i+1, row.getBoolean(i))
                  // case "TimestamType" => statement.setTimestamp(i+1, row.get.getTimestamp(i))
                  // case "DateType" => statement.setDate(i+1, row.getDate(i))
                }
                i += 1
              }
              // ps.setInt(2, pid)
              ps.execute()
            } catch {
              case e: SQLException => log.error("Error writing to SQL DB on row: " + row.toString())
                throw e //TODO: Give users the option to abort or continue
            }
          }
          //streamToDB(sql, row, schemaDatatype)
        })
      })
    } catch {
      case e: Exception =>
        log.error("Error writing batch data to SQL DB. Error details: "+ e)
        throw e
    } finally {
      // connection.close()
    }

    /*
    val table = sqlConf.tableName
    val jdbcWrapper: SqlJDBCWrapper = new SqlJDBCWrapper

    //TODO: Remove the error below.
    throw new SQLException("Currently only directCopy supported")

    connection = jdbcWrapper.setupConnection(sqlConf.connectionString, sqlConf.username, sqlConf.password, Option(DRIVER_NAME)) //TODOv2: Hard-coding the driver name for now
    connection.setAutoCommit(false)
    val mappedData = dataMapper(connection, table, data) //TODO: Handle data type conversions smoothly & check table existence, checks column names and types and map them to dataframe columns
    val schema = mappedData.schema
    if(data.schema == mappedData.schema){ //TODOv2: Instead of this, read from the params, so we don't have to call jdbcWrapper or dataMapper.
      subset = true;
    }
    loadSqlData(connection, subset, sqlConf, mappedData)
    connection.commit()

    //TODOv2: Provide the option to the user to define the columns they're writing to and/or column mapping
    //TODOv2: Handle creation of tables and append/overwrite of tables ; for v1, only append mode
    */
  }

  /*def createPreparedStatement(
      conn: Connection,
      schema: Seq[Attribute]
    ): PreparedStatement = {

  }*/

  /*
  This method checks to see if the table exists in SQL. If it doesn't, it creates the table. This method also ensures that the data types of the data frame are compatible with that of Azure SQL database. If they aren't, it converts them and returns the converted data frame
  */
  def streamToDB(
      ps: PreparedStatement,
      sql: String,
      row: InternalRow,
      colDataTypes: ListBuffer[String]
    ): Unit = {

    if(ps == null) {
      // ps = connection.prepareStatement(sql)
    }

    if(ps != null) {
      try{
        for(i <- 0 to colDataTypes.length-1) {
          colDataTypes(i) match {
            case "ByteType" => ps.setByte(i+1, row.getByte(i))
            case "ShortType" => ps.setShort(i+1, row.getShort(i))
            case "IntegerType" => ps.setInt(i+1, row.getInt(i))
            case "LongType" => ps.setLong(i+1, row.getLong(i))
            case "FloatType" => ps.setFloat(i+1, row.getFloat(i))
            case "DoubleType" => ps.setDouble(i+1, row.getDouble(i))
            // case "DecimalType" => statement.setBigDecimal(i, ) //TODO: try to use getAccessor and find a similar method in statement
            case "StringType" => ps.setString(i+1, row.getString(i))
            case "BinaryType" => ps.setBytes(i+1, row.getBinary(i))
            case "BooleanType" => ps.setBoolean(i+1, row.getBoolean(i))
            // case "TimestamType" => statement.setTimestamp(i+1, row.get.getTimestamp(i))
            // case "DateType" => statement.setDate(i+1, row.getDate(i))
          }
        }
        ps.execute()
      } catch {
        case e: SQLException => log.error("Error writing to SQL DB on row: " + row.toString())
          throw e //TODO: Give users the option to abort or continue
      }
    }

  }

  /*
  Prepares the Insert statement and calls the [[SqlJDBCWrapper.executeCmd]]
  */
  /*
  def loadSqlData(
      conn: Connection,
      subset: Boolean,
      conf: AzureSQLConfig,
      data: DataFrame
    ): Unit = {
    if(subset){
      var schemaStr: String = ""
      val schema = data.schema
      schema.fields.foreach{
        schemaStr += _.name + ","
      }
      schemaStr.substring(0, schemaStr.length-2)
      //val insertStatement = s"INSERT INTO $table ("+ schemaStr + ") VALUES ("
      //TODO: Handle append mode
    } else {
      val connectionProperties = new Properties()
      connectionProperties.put("user", conf.username)
      connectionProperties.put("password", conf.password)
      connectionProperties.put("driver", DRIVER_NAME)
      try{
        data.write.mode(SaveMode.Append).jdbc(conf.connectionString, conf.tableName, connectionProperties)
      } catch {
        case e: Exception =>
          log.error("Error writing batch data to SQL DB. Error details: "+ e)
          throw e
      }
    }
  } */
}

/*
val jdbc_url = s"jdbc:sqlserver://${serverName}:${jdbcPort};database=${database};encrypt=true;trustServerCertificate=false;hostNameInCertificate=*.database.windows.net;loginTimeout=30;"

//Creating Properties
val connectionProperties = new Properties()
connectionProperties.put("user", s"$username")
connectionProperties.put("password", s"$password")
connectionProperties.put("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")

//TODO: Write data to SQL DB

val serverName = parameters.get(SERVER_KEY).map(_.trim)
val portNumber = parameters.get(PORT_KEY).map(_.trim).flatMap(s => Try(s.toInt).toOption)
val database = parameters.get(DB_KEY).map(_.trim)
val username = parameters.get(USER_KEY).map(_.trim)
val password = parameters.get(PWD_KEY).map(_.trim)
val tableName = parameters.get(TABLE_KEY).map(_.trim)

*/
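
To make the prepared-statement construction in write concrete: for a streaming DataFrame with a single string column named input written to dbo.newtesttable (the table used in the test below), the generated statement is INSERT INTO dbo.newtesttable (input) VALUES (?);. An equivalent standalone sketch of the same string building, using mkString instead of the trailing-comma/substring approach, with an illustrative two-column schema:

// Builds the same parameterized INSERT that SQLWriter prepares per partition.
val table = "dbo.newtesttable"
val columnNames = Seq("id", "input")          // illustrative schema

val colNames = columnNames.mkString(",")
val values = Seq.fill(columnNames.size)("?").mkString(",")
val sql = s"INSERT INTO $table ($colNames) VALUES ($values);"
// "INSERT INTO dbo.newtesttable (id,input) VALUES (?,?);"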

@@ -0,0 +1,118 @@
package com.microsoft.azure.sqldb.spark.sql.streaming

import com.microsoft.azure.sqldb.spark.config.Config
import com.microsoft.azure.sqldb.spark.connect._
import org.apache.spark.sql.execution.streaming.MemoryStream
import org.apache.spark.sql.streaming.{DataStreamWriter, StreamTest}
import org.apache.spark.sql.test.SharedSQLContext
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import org.apache.spark.sql.{DataFrame, Row}
import org.scalatest.time.Span
import org.scalatest.time.SpanSugar._
import com.microsoft.azure.sqldb.spark.utils.SQLServerTestUtils

class SQLSinkTest extends StreamTest with SharedSQLContext {
  import testImplicits._

  override val streamingTimeout: Span = 30.seconds
  val url = "localhost"
  val database = "test1"
  val user = "test"
  val password = "test"
  val dbTable = "dbo.newtesttable"
  val portNum = "58502"

  protected var SQLUtils: SQLServerTestUtils = _

  override def beforeAll(): Unit = {
    super.beforeAll()
    SQLUtils = new SQLServerTestUtils
  }

  override def afterAll(): Unit = {
    if(SQLUtils != null) {
      SQLUtils.dropAllTables()
      SQLUtils = null
    }
    super.afterAll()
  }

  private def createReader(testFileLocation: String): DataFrame = {
    //TODO: Prep test data
    val testschema = StructType(
      StructField("input", StringType) :: Nil)
    spark.readStream.schema(testschema).json(testFileLocation)
  }

  test("Structured Streaming - Write to Azure SQL DB") {
    var config = Map(
      "url" -> url,
      "databaseName" -> database,
      "user" -> user,
      "password" -> password,
      "dbTable" -> dbTable,
      "portNumber" -> portNum
    )
    val columns = Map(
      "input" -> "nvarchar(10)"
    )
    var success = SQLUtils.createTable(config, columns)
    if(!success){
      fail("Table creation failed. Please check your config")
    }
    var stream: DataStreamWriter[Row] = null
    val input = MemoryStream[String]
    input.addData("1", "2", "3", "4", "5", "6", "7", "8", "9", "10")
    var df = input.toDF().withColumnRenamed("value", "input")

    //TODO: Create Util functions to create SQL DB, table and truncate it and call the functions here
    withTempDir { checkpointDir =>
      config += ("checkpointLocation" -> checkpointDir.getCanonicalPath) //Adding Checkpoint Location
      stream = df.writeStream
        .format("sqlserver")
        .options(config)
        .outputMode("Append")
    }
    var streamStart = stream.start()

    try {
      failAfter(streamingTimeout){
        streamStart.processAllAvailable()
      }
      checkDatasetUnorderly(spark.read.sqlDB(Config(config)).select($"input").as[String].map(_.toInt), 1, 2, 3, 4, 5, 6, 7, 8, 9, 10)
    } finally {
      success = SQLUtils.dropTable(config)
      if(!success){
        fail("Table deletion failed. Please check your config")
      }
      streamStart.stop()
    }
  }

  /*private def createWriter(inputDF: DataFrame, sqlConfig: AzureSQLConfig, withOutputMode: Option[OutputMode]): StreamingQuery = {
    inputDF.writeStream.format("azuresql").option("directCopy","true").option("")
  }*/

  test("Structured Streaming - Incorrect username/password ; Ensure the right error surfaces"){

  }

  test("Structured Streaming - Incorrect Server Name ; Ensure the right error surfaces"){

  }

  test("Structured Streaming - Incorrect Database Name ; Ensure the right error surfaces"){

  }

  test("Structured Streaming - Incomplete options defined ; Ensure the right error surfaces"){

  }

  test("Structured Streaming - Table does not exist ; Ensure the right error surfaces"){

  }
}

@@ -0,0 +1,5 @@
package com.microsoft.azure.sqldb.spark.sql.streaming

class WriteConfigSpec {

}

@@ -0,0 +1,153 @@
package com.microsoft.azure.sqldb.spark.utils

import java.sql.{Connection, DriverManager}

import com.microsoft.azure.sqldb.spark.config.{Config, SqlDBConfig}
import com.microsoft.azure.sqldb.spark.connect.ConnectionUtils.{createConnectionProperties, createJDBCUrl}

private[spark] class SQLServerTestUtils {
  val DRIVER_NAME = "com.microsoft.sqlserver.jdbc.SQLServerDriver"

  private def getConnection(writeConfig: Config): Connection = {
    val url = writeConfig.get[String](SqlDBConfig.URL).get
    val db = writeConfig.get[String](SqlDBConfig.DatabaseName).get
    val properties = createConnectionProperties(writeConfig)
    val user = writeConfig.get[String](SqlDBConfig.User).get
    val password = writeConfig.get[String](SqlDBConfig.Password).get
    val port = writeConfig.get[String](SqlDBConfig.PortNumber).getOrElse("1433")
    if(db.equals(null) || url.equals(null)){
      return null
    }
    Class.forName(DRIVER_NAME)
    var conn: Connection = DriverManager.getConnection(createJDBCUrl(url, Some(port))+";database="+db, user, password)
    return conn
  }

  def dropAllTables(): Unit = {
    //TODO
  }

  def createTable(config: Map[String, String], columns: Map[String, String]): Boolean = {
    var writeConfig = Config(config)
    val table = writeConfig.get[String](SqlDBConfig.DBTable).getOrElse(null)
    if(table.equals(null)) {
      return false
    }
    val conn = getConnection(writeConfig)
    if(conn.equals(null)){
      return false
    }
    if(!dataTypeVerify(columns)){
      return false
    }
    var columnDef = ""
    for(column <- columns){
      columnDef += column._1 + " " + column._2 + ", "
    }
    var sql = "CREATE TABLE " + table + "(" + columnDef.substring(0, columnDef.length-1) + ");"
    try{
      val stmt = conn.createStatement()
      stmt.executeUpdate(sql)
      conn.close()
    } catch {
      case e: Exception => return false
    }

    return true
  }

  def dropTable(config: Map[String, String]): Boolean = {
    var writeConfig = Config(config)
    val table = writeConfig.get[String](SqlDBConfig.DBTable).getOrElse(null)
    if(table.equals(null)) {
      return false
    }
    val conn = getConnection(writeConfig)
    if(conn.equals(null)){
      return false
    }

    var sql = "DROP TABLE " + table + ";"
    try{
      val stmt = conn.createStatement()
      stmt.executeUpdate(sql)
      conn.close()
    } catch {
      case e: Exception => return false
    }

    return true

  }

  def truncateTable(config: Map[String, String]): Boolean = {
    var writeConfig = Config(config)
    val table = writeConfig.get[String](SqlDBConfig.DBTable).getOrElse(null)
    if(table.equals(null)) {
      return false
    }
    val conn = getConnection(writeConfig)
    if(conn.equals(null)){
      return false
    }

    var sql = "TRUNCATE TABLE " + table + ";"
    try{
      val stmt = conn.createStatement()
      stmt.executeUpdate(sql)
      conn.close()
    } catch {
      case e: Exception => return false
    }

    return true
  }

  private def dataTypeVerify(typeMap: Map[String, String]): Boolean = {
    for (typeName <- typeMap.values) {
      val dataType = typeName.split('(')(0).trim

      dataType match {
        case "bigint" => ""
        case "binary" => ""
        case "bit" => ""
        case "char" => ""
        case "date" => ""
        case "datetime" => ""
        case "datetime2" => ""
        case "datetimeoffset" => ""
        case "decimal" => ""
        case "float" => ""
        case "image" => return false
        case "int" => ""
        case "money" => ""
        case "nchar" => ""
        case "ntext" => ""
        case "numeric" => ""
        case "nvarchar" => ""
        case "nvarchar(max)" => ""
        case "real" => ""
        case "smalldatetime" => ""
        case "smallint" => ""
        case "smallmoney" => ""
        case "text" => ""
        case "time" => ""
        case "timestamp" => ""
        case "tinyint" => ""
        case "udt" => return false
        case "uniqueidentifier" => ""
        case "varbinary" => ""
        //case "varbinary(max)" => ""
        case "varchar" => ""
        //case "varchar(max)" => ""
        case "xml" => ""
        case "sqlvariant" => ""
        case "geometry" => ""
        case "geography" => ""
        case _ => return false
      }
    }
    return true
  }
}
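
For reference, a short sketch of how these helpers are driven, mirroring the config and column map used in SQLSinkTest; the connection values are placeholders, and since the class is private[spark] the calls are shown as if made from within the library's package:

import com.microsoft.azure.sqldb.spark.utils.SQLServerTestUtils

val config = Map(
  "url"          -> "localhost",
  "databaseName" -> "test1",
  "user"         -> "test",
  "password"     -> "test",
  "dbTable"      -> "dbo.newtesttable",
  "portNumber"   -> "58502"
)

val utils = new SQLServerTestUtils
val created = utils.createTable(config, Map("input" -> "nvarchar(10)"))
// ... exercise the streaming write under test ...
val dropped = utils.dropTable(config)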