Although using a relational database as a Spark data source is not a common scenario, we sometimes still want to read data from MySQL or a similar database and wrap it in an RDD. Spark does ship a class for this, org.apache.spark.rdd.JdbcRDD, but in practice it feels half-baked: it does not support arbitrary query conditions, only a lower and an upper bound, which greatly limits where it can be used. The data we want to analyze rarely lives in its own dedicated table; it is usually mixed into one large table, and we would like to pick out exactly the records of interest for analysis.
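For reference, here is a minimal sketch of how the stock org.apache.spark.rdd.JdbcRDD is driven (reusing the login_log table and connection settings from the test program at the end of this post). The SQL must contain exactly two ? placeholders, which Spark fills with each partition's lower and upper bound, and there is no way to bind any further conditions of your own:
import java.sql.{DriverManager, ResultSet}
import org.apache.spark.SparkContext
import org.apache.spark.rdd.JdbcRDD

val sc = new SparkContext("local[2]", "stock_jdbcrdd")
val data = new JdbcRDD(
  sc,
  () => DriverManager.getConnection("jdbc:mysql://localhost:3306/transportation", "root", "pass"),
  // exactly two "?" placeholders, reserved for the partition bounds
  "SELECT * FROM login_log WHERE id >= ? AND id <= ?",
  lowerBound = 1, upperBound = 20, numPartitions = 5,
  mapRow = (r: ResultSet) => (r.getString(1), r.getString(2)))
println(data.collect().toList)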
A look at the JdbcRDD source makes it clear why only start and end bounds are offered.
val stmt = conn.prepareStatement(sql, ResultSet.TYPE_FORWARD_ONLY, ResultSet.CONCUR_READ_ONLY)
val url = conn.getMetaData.getURL
if (url.startsWith("jdbc:mysql:")) {
  // setFetchSize(Integer.MIN_VALUE) is a mysql driver specific way to force
  // streaming results, rather than pulling entire resultset into memory.
  // See the below URL
  // dev.mysql.com/doc/connector-j/5.1/en/connector-j-reference-implementation-notes.html
  stmt.setFetchSize(Integer.MIN_VALUE)
} else {
  stmt.setFetchSize(100)
}
logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")
stmt.setLong(1, part.lower)
stmt.setLong(2, part.upper)
val rs = stmt.executeQuery()
The statement is prepared in cursor style, conn.prepareStatement(sql, type, concurrency), so the only parameters that ever get bound are the partition's start index part.lower and end index part.upper. I searched for a long time and could not find a way to pass my own conditions into this stmt, which was frustrating. So rather than keep fighting it, and without worrying about compatibility with other databases, I target MySQL only: drop the cursor-style call and let each partition fetch its own slice with LIMIT, which can always express the range.
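Before the full listing, here is a minimal standalone sketch (a hypothetical helper, not part of the final class) of how each partition's inclusive bounds turn into a LIMIT offset, count pair. It reuses the same arithmetic as getPartitions below, with lowerBound = 1, upperBound = 20 and numPartitions = 5 as in the test program:
// Hypothetical sketch: the same partition arithmetic as getPartitions below,
// printed as the LIMIT clause each partition will effectively run.
val lowerBound = 1L
val upperBound = 20L
val numPartitions = 5
val length = BigInt(1) + upperBound - lowerBound // bounds are inclusive
(0 until numPartitions).foreach { i =>
  val start = lowerBound + ((i * length) / numPartitions)
  val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
  println(s"partition $i -> LIMIT $start, ${end - start + 1}") // e.g. partition 0 -> LIMIT 1, 4
}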
Here is the concrete implementation.
The rewritten JdbcRDD:
package JdbcRDD

import java.sql.{Connection, ResultSet}
import java.util.ArrayList

import scala.reflect.ClassTag

import org.apache.spark.{Partition, SparkContext, TaskContext}
import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
// Note: NextIterator (and internal.Logging in Spark 2.x) is declared private[spark],
// so to compile this file outside Spark's own packages you either need local copies
// of those helpers or have to place the file under an org.apache.spark.* sub-package.
import org.apache.spark.util.NextIterator
private class JdbcPartition(idx: Int, val lower: Long, val upper: Long, val params: ArrayList[Any]) extends Partition {
  override def index: Int = idx
}

class JdbcRDD[T: ClassTag](
    sc: SparkContext,
    getConnection: () => Connection,
    sql: String,
    lowerBound: Long,
    upperBound: Long,
    params: ArrayList[Any],
    numPartitions: Int,
    mapRow: (ResultSet) => T = JdbcRDD.resultSetToObjectArray _)
  extends RDD[T](sc, Nil) with Logging {

  override def getPartitions: Array[Partition] = {
    // bounds are inclusive, hence the + 1 here and - 1 on end
    val length = BigInt(1) + upperBound - lowerBound
    (0 until numPartitions).map { i =>
      val start = lowerBound + ((i * length) / numPartitions)
      val end = lowerBound + (((i + 1) * length) / numPartitions) - 1
      new JdbcPartition(i, start.toLong, end.toLong, params)
    }.toArray
  }
  override def compute(thePart: Partition, context: TaskContext): Iterator[T] = new NextIterator[T] {
    context.addTaskCompletionListener { context => closeIfNeeded() }
    val part = thePart.asInstanceOf[JdbcPartition]
    val conn = getConnection()
    // an ordinary prepared statement instead of the cursor-style call
    val stmt = conn.prepareStatement(sql)
    val url = conn.getMetaData.getURL
    if (url.startsWith("jdbc:mysql:")) {
      stmt.setFetchSize(Integer.MIN_VALUE)
    } else {
      // this rewrite relies on MySQL's "LIMIT offset, count" syntax, so other databases are not supported
      throw new UnsupportedOperationException(s"only MySQL is supported, got: $url")
    }
    logInfo(s"statement fetch size set to: ${stmt.getFetchSize}")
    // bind the user-supplied query parameters first
    val params = part.params
    val paramsSize = if (params != null) params.size() else 0
    for (i <- 1 to paramsSize) {
      val param = params.get(i - 1)
      param match {
        case param: String => stmt.setString(i, param)
        case param: Int => stmt.setInt(i, param)
        case param: Boolean => stmt.setBoolean(i, param)
        case param: Double => stmt.setDouble(i, param)
        case param: Float => stmt.setFloat(i, param)
        case _ => logWarning(s"unsupported parameter type for value: $param")
      }
    }
    // the last two placeholders limit this partition to its own offset and row count
    stmt.setLong(paramsSize + 1, part.lower)
    stmt.setLong(paramsSize + 2, part.upper - part.lower + 1)
    val rs = stmt.executeQuery()

    override def getNext(): T = {
      if (rs.next()) {
        mapRow(rs)
      } else {
        finished = true
        null.asInstanceOf[T]
      }
    }

    override def close(): Unit = {
      try {
        if (null != rs) {
          rs.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing resultset", e)
      }
      try {
        if (null != stmt) {
          stmt.close()
        }
      } catch {
        case e: Exception => logWarning("Exception closing statement", e)
      }
      try {
        if (null != conn) {
          conn.close()
        }
        logInfo("closed connection")
      } catch {
        case e: Exception => logWarning("Exception closing connection", e)
      }
    }
  }
}
object JdbcRDD {
  def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
    Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
  }

  trait ConnectionFactory extends Serializable {
    @throws[Exception]
    def getConnection: Connection
  }

  def create[T](
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      lowerBound: Long,
      upperBound: Long,
      params: ArrayList[Any],
      numPartitions: Int,
      mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
    val jdbcRDD = new JdbcRDD[T](
      sc.sc,
      () => connectionFactory.getConnection,
      sql,
      lowerBound,
      upperBound,
      params,
      numPartitions,
      (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
    new JavaRDD[T](jdbcRDD)(fakeClassTag)
  }

  def create(
      sc: JavaSparkContext,
      connectionFactory: ConnectionFactory,
      sql: String,
      lowerBound: Long,
      upperBound: Long,
      params: ArrayList[Any],
      numPartitions: Int): JavaRDD[Array[Object]] = {
    val mapRow = new JFunction[ResultSet, Array[Object]] {
      override def call(resultSet: ResultSet): Array[Object] = {
        resultSetToObjectArray(resultSet)
      }
    }
    create(sc, connectionFactory, sql, lowerBound, upperBound, params, numPartitions, mapRow)
  }

  private def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]
}
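Since the companion object keeps the Java-friendly create overloads of the original class, the same RDD can also be built through that path. A minimal sketch, shown in Scala for brevity and reusing the connection settings and query from the test program below:
import java.sql.{Connection, DriverManager}
import org.apache.spark.api.java.JavaSparkContext

// Hypothetical usage of the Java-facing create overload; connection details match the test below.
val jsc = new JavaSparkContext("local[2]", "spark_mysql_java_api")
val params = new java.util.ArrayList[Any]
params.add(100)
params.add(7)
val rows = JdbcRDD.create(
  jsc,
  new JdbcRDD.ConnectionFactory {
    override def getConnection: Connection =
      DriverManager.getConnection("jdbc:mysql://localhost:3306/transportation", "root", "pass")
  },
  "SELECT * FROM login_log where id<=? and user_id=? limit ?,?",
  1, 20, params, 5)
println(rows.collect().size())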
Next, the test code:
package JdbcRDD

import java.sql.{DriverManager, ResultSet}
import java.util

import org.apache.spark.SparkContext

object JdbcRDDTest {
  def main(args: Array[String]) {
    //val conf = new SparkConf().setAppName("spark_mysql").setMaster("local")
    val sc = new SparkContext("local[2]", "spark_mysql")

    def createConnection() = {
      Class.forName("com.mysql.jdbc.Driver").newInstance()
      DriverManager.getConnection("jdbc:mysql://localhost:3306/transportation", "root", "pass")
    }

    def extractValues(r: ResultSet) = {
      (r.getString(1), r.getString(2))
    }

    val params = new util.ArrayList[Any]
    params.add(100) // values bound to the two leading "?" placeholders in the query
    params.add(7)

    val data = new JdbcRDD(
      sc,
      createConnection,
      "SELECT * FROM login_log where id<=? and user_id=? limit ?,?",
      lowerBound = 1,
      upperBound = 20,
      params = params,
      numPartitions = 5,
      mapRow = extractValues)
    data.cache()
    println(data.collect.length)
    println(data.collect().toList)

    sc.stop()
  }
}
Test results:
As the results show, after rewriting JdbcRDD we can query a table with our own conditions and cap the number of rows returned at the same time, which makes analyzing MySQL data with Spark much more convenient: there is no need to filter the data we want into a separate table before analyzing it. The demo is admittedly rough; it is only meant to demonstrate the approach, and it can be polished further later on.