I recently looked into time series forecasting. Most of the resources I found online are implemented in Python, which covers the majority of use cases, but Python has one drawback: it typically runs on a single machine, so when the data volume is large it may not cope. That situation is rare, but it can happen, so I searched for a Spark-based alternative. As it turns out, one exists, and running the computation on a Spark cluster gives a clear speed-up.
Project repository: https://github.com/sryza/spark-timeseries
First of all, many thanks to the blogger below; it was only after studying his code that I could properly understand how to use spark-timeseries.
Blog post: http://blog.csdn.net/qq_30232405/article/details/70622400
Below are my improvements to that code. The main changes make the time format handling generic, so that most timestamp formats work, and allow the ARIMA model's p, d, q parameters to be specified by the user.
TimeFormatUtils.java
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.HashMap;
import java.util.regex.Pattern;
public class TimeFormatUtils {
/**
 * Detect the date/time pattern of a time string.
 *
 * @param timeStr the raw time string
 * @return a matching SimpleDateFormat pattern, or null if none matches
 */
public static String getDateType(String timeStr) {
HashMap<String, String> dateRegFormat = new HashMap<String, String>();
dateRegFormat.put("^\\d{4}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D+\\d{1,2}\\D*$", "yyyy-MM-dd HH:mm:ss");//2014年3月12日 13時(shí)5分34秒,2014-03-12 12:05:34来农,2014/3/12 12:5:34
dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH:mm");//2014-03-12 12:05
dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd HH");//2014-03-12 12
dateRegFormat.put("^\\d{4}\\D+\\d{2}\\D+\\d{2}$", "yyyy-MM-dd");//2014-03-12
dateRegFormat.put("^\\d{4}\\D+\\d{2}$", "yyyy-MM");//2014-03
dateRegFormat.put("^\\d{4}$", "yyyy");//2014
dateRegFormat.put("^\\d{14}$", "yyyyMMddHHmmss");//20140312120534
dateRegFormat.put("^\\d{12}$", "yyyyMMddHHmm");//201403121205
dateRegFormat.put("^\\d{10}$", "yyyyMMddHH");//2014031212
dateRegFormat.put("^\\d{8}$", "yyyyMMdd");//20140312
dateRegFormat.put("^\\d{6}$", "yyyyMM");//201403
for (String key : dateRegFormat.keySet()) {
if (Pattern.compile(key).matcher(timeStr).matches()) {
if (timeStr.contains("/"))
return dateRegFormat.get(key).replaceAll("-", "/");
else
return dateRegFormat.get(key);
}
}
System.err.println("Invalid date format: " + timeStr);
return null;
}
/**
 * Normalize a time string to the canonical "yyyy-MM-dd HH:mm:ss" form.
 */
public static String formatData(String time, SimpleDateFormat format) {
try {
SimpleDateFormat formatter = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
return formatter.format(format.parse(time));
} catch (ParseException e) {
e.printStackTrace();
}
return null;
}
}
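As a quick sanity check, the utility can be exercised like this (a minimal sketch; the input is a made-up example, and the expected results follow from the regex table above):

import java.text.SimpleDateFormat

// "2014/3/12 12:5:34" matches the first pattern; "/" switches the separators
val pattern = TimeFormatUtils.getDateType("2014/3/12 12:5:34") // "yyyy/MM/dd HH:mm:ss"
val sdf = new SimpleDateFormat(pattern)
// normalize to the canonical form used by the rest of the pipeline
println(TimeFormatUtils.formatData("2014/3/12 12:5:34", sdf))  // 2014-03-12 12:05:34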
TimeSeriesTrain.scala
import java.sql.Timestamp
import java.text.SimpleDateFormat
import java.time.{ZoneId, ZonedDateTime}
import com.cloudera.sparkts._
import com.sendi.TimeSeries.Util.TimeFormatUtils
import org.apache.log4j.{Level, Logger}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}
/**
 * Builds the time series model (ARIMA or HoltWinters).
 */
object TimeSeriesTrain {
/**
 * Main driver.
 */
def timeSeries(args: Array[String]) {
args.foreach(println)
Logger.getLogger("org.apache.spark").setLevel(Level.WARN)
Logger.getLogger("org.eclipse.jetty.server").setLevel(Level.OFF)
/**
 * 1. Initialize the Spark environment
 */
val sparkSession = SparkSession.builder
.master("local[4]").appName("SparkTest")
.enableHiveSupport() //enable Hive support
.getOrCreate()
/**
 * 2. Initialize the parameters
 */
//Hive database.table name
val databaseTableName = args(0)
//the input column names must be "time" and "data"
val hiveColumnName = List(args(1).toString.split(","): _*)
//start and end time
val startTime = args(2)
val endTime = args(3)
//detect the time format from the start time
val sdf = new SimpleDateFormat(TimeFormatUtils.getDateType(startTime))
//time-span unit and step
val timeSpanType = args(4)
val timeSpan = args(5).toInt
//number of future values to predict
val predictedN = args(6).toInt
//output table name
val outputTableName = args(7)
var listPDQ: List[String] = List("")
var period = 0
var holtWintersModelType = ""
//model selection (holtwinters or arima)
val modelName = args(8)
//assign the model-specific parameters
if (modelName.equals("arima")) {
listPDQ = List(args(9).toString.split(","): _*)
} else {
//seasonal period (e.g. 12 for monthly, 4 for quarterly data)
period = args(9).toInt
//HoltWinters model type: additive or multiplicative
holtWintersModelType = args(10)
}
/**
 * 3. Read the source data and convert it into an RDD of (time, key, data)
 */
val timeDataKeyDf = readHiveData(sparkSession, databaseTableName, hiveColumnName)
val zonedDateDataDf = timeChangeToDate(sparkSession, timeDataKeyDf, hiveColumnName, startTime, sdf)
/**
 * 4. Create a DateTimeIndex spanning the data: start date + end date + frequency.
 * The date format must match the time column in the database.
 */
val dtIndex = getTimeSpan(startTime, endTime, timeSpanType, timeSpan, sdf)
/**
 * 5. Create the training data
 */
val trainTsrdd = TimeSeriesRDD.timeSeriesRDDFromObservations(dtIndex, zonedDateDataDf,
hiveColumnName(0), hiveColumnName(0) + "Key", hiveColumnName(1))
trainTsrdd.cache()
//fill missing values by linear interpolation
val filledTrainTsrdd = trainTsrdd.fill("linear")
/**
 * 6. Build the model object and train it on the training data
 */
val timeSeriesKeyModel = new TimeSeriesKeyModel(predictedN, outputTableName)
var forecastValue: RDD[(String, Vector)] = sparkSession.sparkContext.parallelize(Seq(("", Vectors.dense(1))))
//model selection
modelName match {
case "arima" => {
//create and train the ARIMA model
val (forecast, coefficients) = timeSeriesKeyModel.arimaModelTrainKey(filledTrainTsrdd, listPDQ)
//save the ARIMA evaluation metrics
forecastValue = forecast
timeSeriesKeyModel.arimaModelKeyEvaluationSave(sparkSession, coefficients, forecast)
}
case "holtwinters" => {
//create and train the HoltWinters (seasonal) model
val (forecast, sse) = timeSeriesKeyModel.holtWintersModelTrainKey(filledTrainTsrdd, period, holtWintersModelType)
//save the HoltWinters evaluation metrics
forecastValue = forecast
timeSeriesKeyModel.holtWintersModelKeyEvaluationSave(sparkSession, sse, forecast)
}
case _ => throw new UnsupportedOperationException("Currently only supports 'arima' and 'holtwinters'")
}
/**
 * 7. Merge the actual and predicted values, attach dates to form a DataFrame (Date, Data), and save
 */
timeSeriesKeyModel.actualForcastDateKeySaveInHive(sparkSession, filledTrainTsrdd, forecastValue, predictedN, startTime,
endTime, timeSpanType, timeSpan, sdf, hiveColumnName)
}
/**
 * Read data from Hive and shape it into (time, data, key) columns.
 *
 * @param sparkSession
 * @param databaseTableName
 * @param hiveColumnName
 */
def readHiveData(sparkSession: SparkSession, databaseTableName: String, hiveColumnName: List[String]): DataFrame = {
//read the data from Hive; the WHERE clause drops any header row whose time field equals the column name
var hiveDataDf = sparkSession.sql("select * from " + databaseTableName + " where " + hiveColumnName(0) + " !='" + hiveColumnName(0) + "'")
.select(hiveColumnName.head, hiveColumnName.tail: _*)
//drop rows with empty values
hiveDataDf = hiveDataDf.filter(hiveColumnName(1) + " != ''")
//add a constant key column named hiveColumnName(0) + "Key"; data * "0" evaluates to 0 for every row
//timeDataKeyDf columns: time, data, timeKey
val timeDataKeyDf = hiveDataDf.withColumn(hiveColumnName(0) + "Key", hiveDataDf(hiveColumnName(1)) * 0.toString)
.select(hiveColumnName(0), hiveColumnName(1), hiveColumnName(0) + "Key")
timeDataKeyDf
}
/**
 * Convert the "time" column into timestamps backed by a fixed format: ZonedDateTime
 * (such as 2007-12-03T10:15:30+01:00 Europe/Paris).
 *
 * @param sparkSession
 * @param timeDataKeyDf
 * @param hiveColumnName
 * @param startTime
 * @param sdf
 * @return
 */
def timeChangeToDate(sparkSession: SparkSession, timeDataKeyDf: DataFrame, hiveColumnName: List[String], startTime: String,
sdf: SimpleDateFormat): DataFrame = {
val rowRDD: RDD[Row] = timeDataKeyDf.rdd.map {
case Row(time, data, key) =>
val date = sdf.parse(time.toString)
val timestamp = new Timestamp(date.getTime)
Row(timestamp, key.toString, data.toString.toDouble)
}
//build the schema and convert to a DataFrame
val field = Seq(
StructField(hiveColumnName(0), TimestampType, true),
StructField(hiveColumnName(0) + "Key", StringType, true),
StructField(hiveColumnName(1), DoubleType, true))
val schema = StructType(field)
val zonedDateDataDf = sparkSession.createDataFrame(rowRDD, schema)
zonedDateDataDf
}
/**
 * Build the uniform time index from the start/end times, the span unit and the step.
 *
 * @param timeSpanType
 * @param timeSpan
 * @param sdf
 * @param startTime
 * @param endTime
 */
def getTimeSpan(startTime: String, endTime: String, timeSpanType: String, timeSpan: Int, sdf: SimpleDateFormat): UniformDateTimeIndex = {
val start = TimeFormatUtils.formatData(startTime, sdf)
val end = TimeFormatUtils.formatData(endTime, sdf)
val zone = ZoneId.systemDefault()
val frequency = timeSpanType match {
case "year" => new YearFrequency(timeSpan)
case "month" => new MonthFrequency(timeSpan)
case "day" => new DayFrequency(timeSpan)
case "hour" => new HourFrequency(timeSpan)
case "minute" => new MinuteFrequency(timeSpan)
}
val dtIndex: UniformDateTimeIndex = DateTimeIndex.uniformFromInterval(
ZonedDateTime.of(start.substring(0, 4).toInt, start.substring(5, 7).toInt, start.substring(8, 10).toInt,
start.substring(11, 13).toInt, start.substring(14, 16).toInt, 0, 0, zone),
ZonedDateTime.of(end.substring(0, 4).toInt, end.substring(5, 7).toInt, end.substring(8, 10).toInt,
end.substring(11, 13).toInt, end.substring(14, 16).toInt, 0, 0, zone),
frequency)
dtIndex
}
}
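For reference, here is how the driver might be invoked for each model. This is only a sketch: the table name, time range and parameter values below are made-up examples.

//ARIMA: args(9) carries "p,d,q"; pass fewer than three values to fall back to ARIMA.autoFit
TimeSeriesTrain.timeSeries(Array(
  "time_series.passenger",  //args(0): Hive database.table (hypothetical)
  "time,data",              //args(1): input column names
  "2014-01-01 00:00:00",    //args(2): start time
  "2016-12-01 00:00:00",    //args(3): end time
  "month", "1",             //args(4), args(5): time-span unit and step
  "12",                     //args(6): number of values to predict
  "time_series.output",     //args(7): output table (hypothetical)
  "arima",                  //args(8): model name
  "1,0,1"))                 //args(9): p,d,q

//HoltWinters: args(9) is the seasonal period, args(10) the model type
TimeSeriesTrain.timeSeries(Array(
  "time_series.passenger", "time,data",
  "2014-01-01 00:00:00", "2016-12-01 00:00:00",
  "month", "1", "12", "time_series.output",
  "holtwinters", "12", "additive"))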
TimeSeriesKeyModel.scala
import java.text.SimpleDateFormat
import java.util.Calendar
import com.cloudera.sparkts.TimeSeriesRDD
import com.cloudera.sparkts.models.{ARIMA, HoltWinters}
import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.stat.Statistics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{Row, SaveMode, SparkSession}
import org.apache.spark.sql.types.{StringType, StructField, StructType}
import scala.collection.mutable.ArrayBuffer
/**
 * Time series models (the data carries an extra key column).
 * Created by llq on 2017/5/3.
 */
class TimeSeriesKeyModel {
//number of future values to predict
private var predictedN = 1
//output table name
private var outputTableName = "time_series.timeseries_output"
def this(predictedN: Int, outputTableName: String) {
this()
this.predictedN = predictedN
this.outputTableName = outputTableName
}
/**
 * Fit an ARIMA model per key (the data carries an extra key column).
 *
 * @param trainTsrdd
 * @return
 */
def arimaModelTrainKey(trainTsrdd: TimeSeriesRDD[String], listPDQ: List[String]): (RDD[(String, Vector)], RDD[(String, (String, (String, String, String), String, String))]) = {
/** Parameter setup **/
val predictedN = this.predictedN
/** Create the ARIMA model **/
//fit one model per key; the resulting RDD is (key, ArimaModel, Vector)
val arimaAndVectorRdd = trainTsrdd.map {
case (key, denseVector) =>
if (listPDQ.size >= 3) {
(key, ARIMA.fitModel(listPDQ(0).toInt, listPDQ(1).toInt, listPDQ(2).toInt, denseVector), denseVector)
} else {
(key, ARIMA.autoFit(denseVector), denseVector)
}
}
/** Report the fitted parameters: the coefficients, the actual p/d/q, the CSS log-likelihood and the AIC **/
val coefficients = arimaAndVectorRdd.map {
case (key, arimaModel, denseVector) =>
(key, (arimaModel.coefficients.mkString(","),
(arimaModel.p.toString,
arimaModel.d.toString,
arimaModel.q.toString),
arimaModel.logLikelihoodCSS(denseVector).toString,
arimaModel.approxAIC(denseVector).toString))
}
coefficients.collect().foreach {
case (key, (coefficients, (p, d, q), logLikelihood, aic)) =>
println(key + " coefficients:" + coefficients + "=>" + "(p=" + p + ",d=" + d + ",q=" + q + ")")
}
/** Forecast the next N values **/
val forecast = arimaAndVectorRdd.map {
case (key, arimaModel, denseVector) =>
(key, arimaModel.forecast(denseVector, predictedN))
}
//forecast returns the original series followed by the predictions; keep only the last predictedN values
val forecastValue = forecast.map {
case (key, value) =>
(key, Vectors.dense(value.toArray.takeRight(predictedN)))
}
println("Arima forecast of next " + predictedN + " observations:")
forecastValue.foreach(println)
(forecastValue, coefficients)
}
/**
 * Save the ARIMA evaluation metrics:
 * coefficients, (p, d, q), logLikelihoodCSS, AIC, mean, variance, standard_deviation, max, min, range, count.
 *
 * @param sparkSession
 * @param coefficients
 * @param forecastValue
 */
def arimaModelKeyEvaluationSave(sparkSession: SparkSession, coefficients: RDD[(String, (String, (String, String, String), String, String))], forecastValue: RDD[(String, Vector)]): Unit = {
/** Transpose the forecast vectors **/
val forecastRdd = forecastValue.map {
case (key, forecast) => forecast.toArray
}
// Split the matrix into one number per line.
val byColumnAndRow = forecastRdd.zipWithIndex.flatMap {
case (row, rowIndex) => row.zipWithIndex.map {
case (number, columnIndex) => columnIndex -> (rowIndex, number)
}
}
// Build up the transposed matrix. Group and sort by column index first.
val byColumn = byColumnAndRow.groupByKey.sortByKey().values
// Then sort by row index.
val transposed = byColumn.map {
indexedRow => indexedRow.toSeq.sortBy(_._1).map(_._2)
}
val summary = Statistics.colStats(transposed.map(value => Vectors.dense(value(0))))
/** Compute the mean, variance, standard deviation, max, min, range and count of the forecasts, then join them onto the model metrics **/
//model evaluation parameters + statistics of the forecast values
val evaluation = coefficients.join(forecastValue.map {
case (key, forecast) =>
(key, (summary.mean.toArray(0).toString,
summary.variance.toArray(0).toString,
math.sqrt(summary.variance.toArray(0)).toString,
summary.max.toArray(0).toString,
summary.min.toArray(0).toString,
(summary.max.toArray(0) - summary.min.toArray(0)).toString,
summary.count.toString))
})
val evaluationRddRow = evaluation.map {
case (key, ((coefficients, pdq, logLikelihoodCSS, aic), (mean, variance, standardDeviation, max, min, range, count))) =>
Row(coefficients, pdq.toString, logLikelihoodCSS, aic, mean, variance, standardDeviation, max, min, range, count)
}
//build the evaluation dataframe
val schemaString = "coefficients,pdq,logLikelihoodCSS,aic,mean,variance,standardDeviation,max,min,range,count"
val schema = StructType(schemaString.split(",").map(fieldName => StructField(fieldName, StringType, true)))
val evaluationDf = sparkSession.createDataFrame(evaluationRddRow, schema)
println("Evaluation in Arima:")
evaluationDf.show()
/**
 * Save this dataframe to Hive
 */
evaluationDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName + "_arima_evaluation")
}
/**
 * Fit a HoltWinters model per key (the data carries an extra key column).
 *
 * @param trainTsrdd
 * @param period
 * @param holtWintersModelType
 * @return
 */
def holtWintersModelTrainKey(trainTsrdd: TimeSeriesRDD[String], period: Int, holtWintersModelType: String): (RDD[(String, Vector)], RDD[(String, Double)]) = {
/** Parameter setup **/
//number of future values to predict
val predictedN = this.predictedN
/** Create the HoltWinters model **/
//fit one model per key; the resulting RDD is (key, HoltWintersModel, Vector)
val holtWintersAndVectorRdd = trainTsrdd.map {
case (key, denseVector) =>
(key, HoltWinters.fitModel(denseVector, period, holtWintersModelType), denseVector)
}
/** Forecast the next N values **/
//HoltWinters.forecast writes into a destination vector, so allocate a placeholder of length predictedN
val predictedVectors = Vectors.dense(new Array[Double](predictedN))
//forecast
val forecast = holtWintersAndVectorRdd.map {
case (key, holtWintersModel, denseVector) =>
(key, holtWintersModel.forecast(denseVector, predictedVectors))
}
println("HoltWinters forecast of next " + predictedN + " observations:")
forecast.foreach(println)
/** HoltWinters evaluation metric: the sum of squared errors (SSE) **/
val sse = holtWintersAndVectorRdd.map {
case (key, holtWintersModel, denseVector) =>
(key, holtWintersModel.sse(denseVector))
}
(forecast, sse)
}
/**
 * Save the HoltWinters evaluation metrics:
 * sse, mean, variance, standard_deviation, max, min, range, count.
 *
 * @param sparkSession
 * @param sse
 * @param forecastValue
 */
def holtWintersModelKeyEvaluationSave(sparkSession: SparkSession, sse: RDD[(String, Double)], forecastValue: RDD[(String, Vector)]): Unit = {
/** Transpose the forecast vectors **/
val forecastRdd = forecastValue.map {
case (key, forecast) => forecast.toArray
}
// Split the matrix into one number per line.
val byColumnAndRow = forecastRdd.zipWithIndex.flatMap {
case (row, rowIndex) => row.zipWithIndex.map {
case (number, columnIndex) => columnIndex -> (rowIndex, number)
}
}
// Build up the transposed matrix. Group and sort by column index first.
val byColumn = byColumnAndRow.groupByKey.sortByKey().values
// Then sort by row index.
val transposed = byColumn.map {
indexedRow => indexedRow.toSeq.sortBy(_._1).map(_._2)
}
val summary = Statistics.colStats(transposed.map(value => Vectors.dense(value(0))))
/** Compute the mean, variance, standard deviation, max, min, range and count of the forecasts, then join them onto the model metrics **/
//model evaluation parameters + statistics of the forecast values
val evaluation = sse.join(forecastValue.map {
case (key, forecast) =>
(key, (summary.mean.toArray(0).toString,
summary.variance.toArray(0).toString,
math.sqrt(summary.variance.toArray(0)).toString,
summary.max.toArray(0).toString,
summary.min.toArray(0).toString,
(summary.max.toArray(0) - summary.min.toArray(0)).toString,
summary.count.toString))
})
val evaluationRddRow = evaluation.map {
case (key, (sse, (mean, variance, standardDeviation, max, min, range, count))) =>
Row(sse.toString, mean, variance, standardDeviation, max, min, range, count)
}
//build the evaluation dataframe
val schemaString = "sse,mean,variance,standardDeviation,max,min,range,count"
val schema = StructType(schemaString.split(",").map(fieldName => StructField(fieldName, StringType, true)))
val evaluationDf = sparkSession.createDataFrame(evaluationRddRow, schema)
println("Evaluation in HoltWinters:")
evaluationDf.show()
/**
 * Save to Hive
 */
evaluationDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName + "_holtwinters_evaluation")
}
/**
 * Save the (date, data) rows to Hive.
 *
 * @param sparkSession
 * @param dateDataRdd
 * @param hiveColumnName
 */
private def keySaveInHive(sparkSession: SparkSession, dateDataRdd: RDD[Row], hiveColumnName: List[String]): Unit = {
//convert dateDataRdd into a dataframe
val schemaString = hiveColumnName(0) + " " + hiveColumnName(1)
val schema = StructType(schemaString.split(" ")
.map(fieldName => StructField(fieldName, StringType, true)))
val dateDataDf = sparkSession.createDataFrame(dateDataRdd, schema)
//save dateDataDf into Hive
dateDataDf.write.mode(SaveMode.Overwrite).saveAsTable(outputTableName)
}
/**
 * Merge the actual and predicted values, attach dates to form a DataFrame (Date, Data), and save.
 *
 * @param sparkSession
 * @param trainTsrdd
 * @param forecastValue
 * @param predictedN
 * @param startTime
 * @param endTime
 * @param timeSpanType
 * @param timeSpan
 * @param sdf
 * @param hiveColumnName
 */
def actualForcastDateKeySaveInHive(sparkSession: SparkSession, trainTsrdd: TimeSeriesRDD[String], forecastValue: RDD[(String, Vector)],
predictedN: Int, startTime: String, endTime: String, timeSpanType: String, timeSpan: Int,
sdf: SimpleDateFormat, hiveColumnName: List[String]): Unit = {
//append the forecast values after the actual values
val actualAndForcastRdd = trainTsrdd.map {
case (key, actualValue) => (key, actualValue.toArray.mkString(","))
}.join(forecastValue.map {
case (key, forecastValue) => (key, forecastValue.toArray.mkString(","))
})
//generate every date from the start of training through the end of the forecast, as an RDD
val dateArray = productStartDatePredictDate(predictedN, timeSpanType, timeSpan, sdf, startTime, endTime)
val dateRdd = sparkSession.sparkContext.parallelize(dateArray.toArray.mkString(",").split(",").map(date => (date)))
//zip dates with values to form RDD[Row], per key
val actualAndForcastArray = actualAndForcastRdd.collect()
for (i <- 0 until actualAndForcastArray.length) {
val dateDataRdd = actualAndForcastArray(i) match {
case (key, value) =>
val actualAndForcast = sparkSession.sparkContext.parallelize(value.toString().split(",")
.map(data => data.replaceAll("\\(", "").replaceAll("\\)", "")))
dateRdd.zip(actualAndForcast).map {
case (date, data) => Row(date, data)
}
}
//save the result
if (dateDataRdd.collect()(0).toString() != "[1]") {
keySaveInHive(sparkSession, dateDataRdd, hiveColumnName)
}
}
}
/**
 * Generate the run of dates from the start of the training data through the end of the forecast.
 *
 * @param predictedN
 * @param timeSpanType
 * @param timeSpan
 * @param format
 * @param startTime
 * @param endTime
 * @return
 */
def productStartDatePredictDate(predictedN: Int, timeSpanType: String, timeSpan: Int,
format: SimpleDateFormat, startTime: String, endTime: String): ArrayBuffer[String] = {
//dates from the training start through the last predicted step
val cal1 = Calendar.getInstance()
cal1.setTime(format.parse(startTime))
val cal2 = Calendar.getInstance()
cal2.setTime(format.parse(endTime))
/**
 * Compute the number of steps between the start and end times
 */
var field = 1
var diff: Long = 0
timeSpanType match {
case "year" => {
field = Calendar.YEAR
diff = (cal2.getTime.getYear() - cal1.getTime.getYear()) / timeSpan + predictedN
}
case "month" => {
field = Calendar.MONTH
diff = ((cal2.getTime.getYear() - cal1.getTime.getYear()) * 12 + (cal2.getTime.getMonth() - cal1.getTime.getMonth())) / timeSpan + predictedN
}
case "day" => {
field = Calendar.DATE
diff = (cal2.getTimeInMillis - cal1.getTimeInMillis) / (1000 * 60 * 60 * 24) / timeSpan + predictedN
}
case "hour" => {
field = Calendar.HOUR
diff = (cal2.getTimeInMillis - cal1.getTimeInMillis) / (1000 * 60 * 60) / timeSpan + predictedN
}
case "minute" => {
field = Calendar.MINUTE
diff = (cal2.getTimeInMillis - cal1.getTimeInMillis) / (1000 * 60) / timeSpan + predictedN
}
}
var iDiff = 0L
var dateArrayBuffer = new ArrayBuffer[String]()
while (iDiff <= diff) {
//record the current date, then step forward one interval
dateArrayBuffer += format.format(cal1.getTime)
cal1.add(field, timeSpan)
iDiff = iDiff + 1
}
dateArrayBuffer
}
}
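To see what the sparkts model API is doing underneath all the RDD plumbing, here is a minimal single-series sketch (local, no cluster; the sample values are made up):

import com.cloudera.sparkts.models.ARIMA
import org.apache.spark.mllib.linalg.Vectors

val ts = Vectors.dense(Array(112.0, 118, 132, 129, 121, 135, 148, 148, 136, 119, 104, 118))
//fixed p,d,q, as when listPDQ is supplied; ARIMA.autoFit(ts) would search for them instead
val model = ARIMA.fitModel(1, 0, 1, ts)
println("coefficients: " + model.coefficients.mkString(","))
//forecast returns the original series followed by the next 3 predicted values
val forecasted = model.forecast(ts, 3)
println(forecasted.toArray.takeRight(3).mkString(","))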