Translated from: Extending Spark Datasource API: write a custom spark datasource
Data Source API
Basic Interfaces
- BaseRelation: represents the relation (table) of the underlying data source exposed to a DataFrame and defines how its schema is produced; in other words, it is the data source's relation.
- RelationProvider: takes a map of parameters and returns a BaseRelation.
- TableScan: performs a full, unfiltered scan of the data and returns it as an RDD of rows.
- DataSourceRegister: defines a short name (alias) for the data source.
Providers
- SchemaRelationProvider: lets the user supply a custom schema.
- CreatableRelationProvider: lets the user create a new relation from a DataFrame, i.e. write data out.
Scans
- PrunedScan: prunes the columns that are not needed.
- PrunedFilteredScan: prunes unneeded columns and additionally filters rows on column values.
- CatalystScan: an experimental interface for a more direct connection to the query planner.
Relations
- InsertableRelation: supports inserting data, under three assumptions: 1. the data passed to the insert method matches the schema defined by the BaseRelation; 2. the schema does not change; 3. all fields of the inserted data are nullable.
- HadoopFsRelation: a data source backed by a Hadoop-compatible file system.
Output Interfaces
You will need these when building on HadoopFsRelation. A compact sketch of the trait shapes listed above follows.
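To keep the rest of the post easier to follow, here is a rough sketch of the shapes of these interfaces, mirroring how they are overridden later against Spark 2.2; the authoritative definitions live in org.apache.spark.sql.sources, so treat this purely as an orientation aid:

// Sketch only: paraphrased shapes, not Spark's real definitions.
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Row, SQLContext, SaveMode}
import org.apache.spark.sql.sources.Filter
import org.apache.spark.sql.types.StructType

abstract class BaseRelation {
  def sqlContext: SQLContext
  def schema: StructType
}
trait RelationProvider {
  def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation
}
trait SchemaRelationProvider {
  def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation
}
trait CreatableRelationProvider {
  def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation
}
trait TableScan          { def buildScan(): RDD[Row] }
trait PrunedScan         { def buildScan(requiredColumns: Array[String]): RDD[Row] }
trait PrunedFilteredScan { def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] }
trait InsertableRelation { def insert(data: DataFrame, overwrite: Boolean): Unit }
trait DataSourceRegister { def shortName(): String }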
Preparation
Data format
We will use plain text files as the data source. Fields within a record are separated by commas and records are separated by newlines. The data looks like this:
// id, name, gender (1 = male, 0 = female), salary, expenses
10001,Alice,0,30000,12000
Creating the project
Create a Maven project in IntelliJ IDEA and add the corresponding Spark and Scala dependencies.
<properties>
    <scala.version>2.11.8</scala.version>
    <spark.version>2.2.0</spark.version>
</properties>
<repositories>
    <repository>
        <id>scala-tools.org</id>
        <name>Scala-Tools Maven2 Repository</name>
        <url>http://scala-tools.org/repo-releases</url>
    </repository>
</repositories>
<pluginRepositories>
    <pluginRepository>
        <id>scala-tools.org</id>
        <name>Scala-Tools Maven2 Repository</name>
        <url>http://scala-tools.org/repo-releases</url>
    </pluginRepository>
</pluginRepositories>
<dependencies>
    <dependency>
        <groupId>org.scala-lang</groupId>
        <artifactId>scala-library</artifactId>
        <version>${scala.version}</version>
    </dependency>
    <dependency>
        <groupId>junit</groupId>
        <artifactId>junit</artifactId>
        <version>4.4</version>
    </dependency>
    <dependency>
        <groupId>org.specs</groupId>
        <artifactId>specs</artifactId>
        <version>1.2.5</version>
        <scope>test</scope>
    </dependency>
    <dependency>
        <groupId>org.apache.spark</groupId>
        <artifactId>spark-sql_2.11</artifactId>
        <version>${spark.version}</version>
    </dependency>
</dependencies>
Writing the custom data source
Creating the schema
To supply a custom schema, we must create a class named DefaultSource (the Spark source code expects this name; if the class is not called DefaultSource, Spark fails with an error saying the DefaultSource class cannot be found).
The class also needs to extend RelationProvider and SchemaRelationProvider: RelationProvider is used to create the relation for the data, while SchemaRelationProvider lets the caller pass in an explicit schema.
In DefaultSource.scala, when a path parameter is supplied we create the corresponding relation that reads the file at that path.
DefaultSource.scala:
package com.edu.spark.text

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

class DefaultSource
  extends RelationProvider
  with SchemaRelationProvider {

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    // No user-supplied schema: delegate to the schema-aware variant with null.
    createRelation(sqlContext, parameters, null)
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    val path = parameters.get("path")
    path match {
      case Some(p) => new TextDataSourceRelation(sqlContext, p, schema)
      case _ => throw new IllegalArgumentException("Path is required for custom-datasource format!!")
    }
  }
}
The relation itself must extend BaseRelation and override schema to describe the custom data source. (For formats such as parquet/csv/json, the schema can be obtained directly from the files.)
It also mixes in Serializable so that it can be shipped over the network.
TextDataSourceRelation.scala:
package com.edu.spark.text

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types._

class TextDataSourceRelation(override val sqlContext: SQLContext, path: String, userSchema: StructType)
  extends BaseRelation
  with Serializable {

  override def schema: StructType = {
    if (userSchema != null) {
      // Use the schema supplied by the user, if any.
      userSchema
    } else {
      // Otherwise fall back to the fixed schema of our text format.
      StructType(
        StructField("id", IntegerType, false) ::
        StructField("name", StringType, false) ::
        StructField("gender", StringType, false) ::
        StructField("salary", LongType, false) ::
        StructField("expenses", LongType, false) :: Nil)
    }
  }
}
With the code above in place, we can run a quick test to check that the correct schema is returned.
The test reads the file with sqlContext.read, points format at the package of our custom data source, and calls printSchema() to verify the schema.
import org.apache.spark.SparkConf
import org.apache.spark.sql.SparkSession

object TestApp extends App {
  println("Application Started...")
  val conf = new SparkConf().setAppName("spark-custom-datasource")
  val spark = SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df = spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  println("output schema...")
  df.printSchema()
  println("Application Ended...")
}
The printed schema:
output schema...
root
|-- id: integer (nullable = false)
|-- name: string (nullable = false)
|-- gender: string (nullable = false)
|-- salary: long (nullable = false)
|-- expenses: long (nullable = false)
The printed schema matches the one we defined.
Reading data
To read the data, TextDataSourceRelation must also implement TableScan and its buildScan() method.
This method returns the data as an RDD of Row objects, one Row per record.
The file is read with wholeTextFiles on the given path, which yields (fileName, content) pairs.
After reading, each line is split on commas, the numeric gender field is mapped to the corresponding string, and every value is cast to the type declared in the schema.
The casting helper:
package com.edu.spark.text

import org.apache.spark.sql.types.{DataType, IntegerType, LongType, StringType}

object Util {
  // Cast a raw string value to the Spark SQL type declared in the schema.
  def castTo(value: String, dataType: DataType): Any = {
    dataType match {
      case _: IntegerType => value.toInt
      case _: LongType    => value.toLong
      case _: StringType  => value
      case _              => value // fall back to the raw string for types we don't handle
    }
  }
}
The TableScan implementation:
// Add "with TableScan" to TextDataSourceRelation, plus imports for org.apache.spark.rdd.RDD and org.apache.spark.sql.Row.
override def buildScan(): RDD[Row] = {
  println("TableScan: buildScan called...")
  val schemaFields = schema.fields
  // Reading the file's content as (fileName, content) pairs and keeping only the content.
  val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(f => f._2)
  val rows = rdd.map(fileContent => {
    val lines = fileContent.split("\n")
    val data = lines.map(line => line.split(",").map(word => word.trim).toSeq)
    val tmp = data.map(words => words.zipWithIndex.map {
      case (value, index) =>
        val colName = schemaFields(index).name
        Util.castTo(
          // Translate the numeric gender code into a readable string before casting.
          if (colName.equalsIgnoreCase("gender")) {
            if (value.toInt == 1) {
              "Male"
            } else {
              "Female"
            }
          } else {
            value
          }, schemaFields(index).dataType)
    })
    tmp.map(s => Row.fromSeq(s))
  })
  rows.flatMap(e => e)
}
Test that the data can be read:
object TestApp extends App {
  println("Application Started...")
  val conf = new SparkConf().setAppName("spark-custom-datasource")
  val spark = SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df = spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  df.show()
  println("Application Ended...")
}
The resulting data:
+-----+---------------+------+------+--------+
| id| name|gender|salary|expenses|
+-----+---------------+------+------+--------+
|10002| Alice Heady|Female| 20000| 8000|
|10003| Jenny Brown|Female| 30000| 120000|
|10004| Bob Hayden| Male| 40000| 16000|
|10005| Cindy Heady|Female| 50000| 20000|
|10006| Doug Brown| Male| 60000| 24000|
|10007|Carolina Hayden|Female| 70000| 280000|
+-----+---------------+------+------+--------+
Writing data
Two output formats are implemented here: a custom format and JSON.
Extend CreatableRelationProvider and implement its createRelation method.
package com.edu.spark.text

import org.apache.hadoop.fs.Path
import org.apache.spark.sql.{DataFrame, SQLContext, SaveMode}
import org.apache.spark.sql.sources.{BaseRelation, CreatableRelationProvider, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

class DefaultSource
  extends RelationProvider
  with SchemaRelationProvider
  with CreatableRelationProvider {

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation = {
    createRelation(sqlContext, parameters, null)
  }

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation = {
    val path = parameters.get("path")
    path match {
      case Some(p) => new TextDataSourceRelation(sqlContext, p, schema)
      case _ => throw new IllegalArgumentException("Path is required for custom-datasource format!!")
    }
  }

  override def createRelation(sqlContext: SQLContext, mode: SaveMode, parameters: Map[String, String], data: DataFrame): BaseRelation = {
    val path = parameters.getOrElse("path", "./output/") // can throw an exception/error, it's just for this tutorial
    val fsPath = new Path(path)
    val fs = fsPath.getFileSystem(sqlContext.sparkContext.hadoopConfiguration)

    mode match {
      case SaveMode.Append => sys.error("Append mode is not supported by " + this.getClass.getCanonicalName)
      case SaveMode.Overwrite => fs.delete(fsPath, true)
      case SaveMode.ErrorIfExists => if (fs.exists(fsPath)) sys.error("Given path: " + path + " already exists!!")
      case SaveMode.Ignore => sys.exit()
    }

    val formatName = parameters.getOrElse("format", "customFormat")
    formatName match {
      case "customFormat" => saveAsCustomFormat(data, path, mode)
      case "json" => saveAsJson(data, path, mode)
      case _ => throw new IllegalArgumentException(formatName + " is not supported!!!")
    }
    createRelation(sqlContext, parameters, data.schema)
  }

  private def saveAsJson(data: DataFrame, path: String, mode: SaveMode): Unit = {
    /**
     * Here, I am using the dataframe's API for storing it as json.
     * You can have your own APIs and ways for saving!!
     */
    data.write.mode(mode).json(path)
  }

  private def saveAsCustomFormat(data: DataFrame, path: String, mode: SaveMode): Unit = {
    /**
     * Here, I am going to save this as a simple text file whose values are separated by "|".
     * But you can have your own way to store it without any restriction.
     */
    val customFormatRDD = data.rdd.map(row => {
      row.toSeq.map(value => value.toString).mkString("|")
    })
    customFormatRDD.saveAsTextFile(path)
  }
}
Test code:
object TestApp extends App {
  println("Application Started...")
  val conf = new SparkConf().setAppName("spark-custom-datasource")
  val spark = SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df = spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  //save the data
  df.write.options(Map("format" -> "customFormat")).mode(SaveMode.Overwrite).format("com.edu.spark.text").save("/Users//Downloads/out_custom/")
  df.write.options(Map("format" -> "json")).mode(SaveMode.Overwrite).format("com.edu.spark.text").save("/Users//Downloads/out_json/")
  df.write.mode(SaveMode.Overwrite).format("com.edu.spark.text").save("/Users//Downloads/out_none/")
  println("Application Ended...")
}
The output:
Custom format:
10002|Alice Heady|Female|20000|8000
10003|Jenny Brown|Female|30000|120000
10004|Bob Hayden|Male|40000|16000
10005|Cindy Heady|Female|50000|20000
10006|Doug Brown|Male|60000|24000
10007|Carolina Hayden|Female|70000|280000
JSON format:
{"id":10002,"name":"Alice Heady","gender":"Female","salary":20000,"expenses":8000}
{"id":10003,"name":"Jenny Brown","gender":"Female","salary":30000,"expenses":120000}
{"id":10004,"name":"Bob Hayden","gender":"Male","salary":40000,"expenses":16000}
{"id":10005,"name":"Cindy Heady","gender":"Female","salary":50000,"expenses":20000}
{"id":10006,"name":"Doug Brown","gender":"Male","salary":60000,"expenses":24000}
{"id":10007,"name":"Carolina Hayden","gender":"Female","salary":70000,"expenses":280000}
Pruning columns
Mix in PrunedScan and implement its buildScan(requiredColumns) method so that only the required columns are returned.
// Requires "with PrunedScan" on TextDataSourceRelation.
override def buildScan(requiredColumns: Array[String]): RDD[Row] = {
  println("PrunedScan: buildScan called...")
  val schemaFields = schema.fields
  // Reading the file's content
  val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(f => f._2)
  val rows = rdd.map(fileContent => {
    val lines = fileContent.split("\n")
    val data = lines.map(line => line.split(",").map(word => word.trim).toSeq)
    val tmp = data.map(words => words.zipWithIndex.map {
      case (value, index) =>
        val colName = schemaFields(index).name
        val castedValue = Util.castTo(
          if (colName.equalsIgnoreCase("gender")) { if (value.toInt == 1) "Male" else "Female" } else value,
          schemaFields(index).dataType)
        // Keep only the columns Spark asked for.
        if (requiredColumns.contains(colName)) Some(castedValue) else None
    })
    // Note: the kept values stay in schema order; Spark expects them in requiredColumns order,
    // which happens to coincide for the queries in this post.
    tmp.map(s => Row.fromSeq(s.filter(_.isDefined).map(value => value.get)))
  })
  rows.flatMap(e => e)
}
Test code:
object TestApp extends App {
  println("Application Started...")
  val conf = new SparkConf().setAppName("spark-custom-datasource")
  val spark = SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df = spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  //select some specific columns
  df.createOrReplaceTempView("test")
  spark.sql("select id, name, salary from test").show()
  println("Application Ended...")
}
The output:
+-----+---------------+------+
|   id|           name|salary|
+-----+---------------+------+
|10002| Alice Heady| 20000|
|10003| Jenny Brown| 30000|
|10004| Bob Hayden| 40000|
|10005| Cindy Heady| 50000|
|10006| Doug Brown| 60000|
|10007|Carolina Hayden| 70000|
+-----+---------------+------+
Filtering
Mix in PrunedFilteredScan and implement its buildScan(requiredColumns, filters) method, which also receives the filters that Spark pushes down.
// Requires "with PrunedFilteredScan" and import org.apache.spark.sql.sources.{EqualTo, Filter, GreaterThan}.
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
  println("PrunedFilterScan: buildScan called...")
  println("Filters: ")
  filters.foreach(f => println(f.toString))
  var customFilters: Map[String, List[CustomFilter]] = Map[String, List[CustomFilter]]()
  filters.foreach {
    case EqualTo(attr, value) =>
      println("EqualTo filter is used!!" + "Attribute: " + attr + " Value: " + value)
      /**
       * As we implement only a couple of filters for now, keeping a list per attribute may not seem
       * to make much sense, because an attribute can only be equal to one value at a time.
       * It becomes useful when there is more than one filter on the same attribute, for example:
       * attr > 5 && attr < 10
       * For such cases it is better to keep a list.
       * You can add more filters to this code and try them out; here we implement only the equalTo
       * and greaterThan filters to illustrate the concept.
       */
      customFilters = customFilters ++ Map(attr -> {
        customFilters.getOrElse(attr, List[CustomFilter]()) :+ new CustomFilter(attr, value, "equalTo")
      })
    case GreaterThan(attr, value) =>
      println("GreaterThan Filter is used!!" + "Attribute: " + attr + " Value: " + value)
      customFilters = customFilters ++ Map(attr -> {
        customFilters.getOrElse(attr, List[CustomFilter]()) :+ new CustomFilter(attr, value, "greaterThan")
      })
    case f => println("filter: " + f.toString + " is not implemented by us!!")
  }
  val schemaFields = schema.fields
  // Reading the file's content
  val rdd = sqlContext.sparkContext.wholeTextFiles(path).map(f => f._2)
  val rows = rdd.map(file => {
    val lines = file.split("\n")
    val data = lines.map(line => line.split(",").map(word => word.trim).toSeq)
    // Drop the rows that do not satisfy all filters registered for their columns.
    val filteredData = data.map(s => if (customFilters.nonEmpty) {
      var includeInResultSet = true
      s.zipWithIndex.foreach {
        case (value, index) =>
          val attr = schemaFields(index).name
          val filtersList = customFilters.getOrElse(attr, List())
          if (filtersList.nonEmpty && !CustomFilter.applyFilters(filtersList, value, schema)) {
            includeInResultSet = false
          }
      }
      if (includeInResultSet) s else Seq()
    } else s)
    val tmp = filteredData.filter(_.nonEmpty).map(s => s.zipWithIndex.map {
      case (value, index) =>
        val colName = schemaFields(index).name
        val castedValue = Util.castTo(
          if (colName.equalsIgnoreCase("gender")) { if (value.toInt == 1) "Male" else "Female" } else value,
          schemaFields(index).dataType)
        if (requiredColumns.contains(colName)) Some(castedValue) else None
    })
    tmp.map(s => Row.fromSeq(s.filter(_.isDefined).map(value => value.get)))
  })
  rows.flatMap(e => e)
}
Test code:
object TestApp extends App {
  println("Application Started...")
  val conf = new SparkConf().setAppName("spark-custom-datasource")
  val spark = SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df = spark.sqlContext.read.format("com.edu.spark.text").load("/Users/Downloads/data")
  //filter data
  df.createOrReplaceTempView("test")
  spark.sql("select id,name,gender from test where salary == 50000").show()
  println("Application Ended...")
}
The output:
+-----+-----------+------+
| id| name|gender|
+-----+-----------+------+
|10005|Cindy Heady|Female|
+-----+-----------+------+
Registering the custom data source
Implement DataSourceRegister and override its shortName.
The implementation:
override def shortName(): String = "udftext"
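In context, shortName lives on the DefaultSource class alongside the provider traits. A condensed sketch showing the read path only (the CreatableRelationProvider pieces from earlier are omitted here for brevity):

package com.edu.spark.text

import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.sources.{BaseRelation, DataSourceRegister, RelationProvider, SchemaRelationProvider}
import org.apache.spark.sql.types.StructType

class DefaultSource
  extends RelationProvider
  with SchemaRelationProvider
  with DataSourceRegister {

  // The alias that can be passed to .format(...) instead of the full package name.
  override def shortName(): String = "udftext"

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String]): BaseRelation =
    createRelation(sqlContext, parameters, null)

  override def createRelation(sqlContext: SQLContext, parameters: Map[String, String], schema: StructType): BaseRelation =
    parameters.get("path") match {
      case Some(p) => new TextDataSourceRelation(sqlContext, p, schema)
      case _       => throw new IllegalArgumentException("Path is required for custom-datasource format!!")
    }
}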
Then, under the resources directory, create a file named META-INF/services/org.apache.spark.sql.sources.DataSourceRegister with the following content:
com.edu.spark.text.DefaultSource
Test code:
object TestApp extends App {
  println("Application Started...")
  val conf = new SparkConf().setAppName("spark-custom-datasource")
  val spark = SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df = spark.sqlContext.read.format("udftext").load("/Users/Downloads/data")
  df.show()
  println("Application Ended...")
}
The output:
+-----+---------------+------+------+--------+
| id| name|gender|salary|expenses|
+-----+---------------+------+------+--------+
|10002| Alice Heady|Female| 20000| 8000|
|10003| Jenny Brown|Female| 30000| 120000|
|10004| Bob Hayden| Male| 40000| 16000|
|10005| Cindy Heady|Female| 50000| 20000|
|10006| Doug Brown| Male| 60000| 24000|
|10007|Carolina Hayden|Female| 70000| 280000|
+-----+---------------+------+------+--------+
We can also write a small DataFrameReader extension so the custom data source can be loaded with a shorthand method:
import org.apache.spark.sql.{DataFrame, DataFrameReader}

object ReaderObject {
  implicit class UDFTextReader(val reader: DataFrameReader) extends AnyVal {
    def udftext(path: String): DataFrame = reader.format("udftext").load(path)
  }
}
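By analogy, a writer-side shorthand can be defined in the same way. A minimal sketch (the WriterObject name and its udftext method are hypothetical, not part of the original post):

import org.apache.spark.sql.{DataFrameWriter, Row}

object WriterObject {
  // Hypothetical counterpart of UDFTextReader for the write path.
  implicit class UDFTextWriter(val writer: DataFrameWriter[Row]) extends AnyVal {
    def udftext(path: String): Unit = writer.format("udftext").save(path)
  }
}

With import WriterObject._ in scope, a call such as df.write.mode(SaveMode.Overwrite).udftext(path) goes through the same save logic shown earlier.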
Test code (the implicit conversion must be imported so it is available on DataFrameReader):
import ReaderObject._

object TestApp extends App {
  println("Application Started...")
  val conf = new SparkConf().setAppName("spark-custom-datasource")
  val spark = SparkSession.builder().config(conf).master("local[2]").getOrCreate()
  val df = spark.sqlContext.read.udftext("/Users/Downloads/data")
  df.show()
  println("Application Ended...")
}
The output is the same as above, so it is not repeated here.
Appendix: CustomFilter.scala
package com.edu.spark.text

import org.apache.spark.sql.types.StructType

case class CustomFilter(attr: String, value: Any, filter: String)

object CustomFilter {
  def applyFilters(filters: List[CustomFilter], value: String, schema: StructType): Boolean = {
    var includeInResultSet = true
    val schemaFields = schema.fields
    val index = schema.fieldIndex(filters.head.attr)
    val dataType = schemaFields(index).dataType
    val castedValue = Util.castTo(value, dataType)
    filters.foreach(f => {
      val givenValue = Util.castTo(f.value.toString, dataType)
      f.filter match {
        case "equalTo" =>
          includeInResultSet = castedValue == givenValue
          println("custom equalTo filter is used!!")
        case "greaterThan" =>
          // Compare numerically (or lexicographically for strings) rather than by equality.
          includeInResultSet = (castedValue, givenValue) match {
            case (c: Int, g: Int)       => c > g
            case (c: Long, g: Long)     => c > g
            case (c: String, g: String) => c > g
            case _                      => false
          }
          println("custom greaterThan filter is used!!")
        case _ => throw new UnsupportedOperationException("this filter is not supported!!")
      }
    })
    includeInResultSet
  }
}
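As a quick sanity check of the helper, a tiny driver with hypothetical values against a cut-down schema (not part of the original post):

package com.edu.spark.text

import org.apache.spark.sql.types.{IntegerType, LongType, StructField, StructType}

object FilterCheck extends App {
  // A cut-down schema: only the columns needed to exercise the filter.
  val schema = StructType(
    StructField("id", IntegerType, false) ::
    StructField("salary", LongType, false) :: Nil)

  val filters = List(CustomFilter("salary", 40000, "greaterThan"))
  println(CustomFilter.applyFilters(filters, "50000", schema)) // true:  50000 > 40000
  println(CustomFilter.applyFilters(filters, "30000", schema)) // false: 30000 is not greater than 40000
}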