
Spark: converting all DataFrame columns to double

Preface

Spark's machine learning APIs require the input DataFrame columns to be numeric, so if the raw data is read in as string columns, each one has to be cast individually. With many columns that gets tedious, so the question is whether a single loop or function can convert them all.

1. Casting a single column

import org.apache.spark.sql.types._
import org.apache.spark.sql.functions._

val data = Array(("1", "2", "3", "4", "5"), ("6", "7", "8", "9", "10"))
val df = spark.createDataFrame(data).toDF("col1", "col2", "col3", "col4", "col5")
df.select(col("col1").cast(DoubleType)).show()
+----+
|col1|
+----+
| 1.0|
| 6.0|
+----+
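
As a small aside, cast also accepts the type name as a plain string, and printSchema is a quick way to check the resulting type. A minimal sketch, using the same df and imports as above:

// "double" as a string is equivalent to DoubleType
df.select(col("col1").cast("double")).printSchema()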

2. Casting in a loop

The next thought was whether this method could be looped to cast every column to double. It was not obvious how to do that with select alone, but it can be done by looping with withColumn.

val colNames = df.columns

var df1 = df
for (colName <- colNames) {
  df1 = df1.withColumn(colName, col(colName).cast(DoubleType))
}
df1.show()
+----+----+----+----+----+
|col1|col2|col3|col4|col5|
+----+----+----+----+----+
| 1.0| 2.0| 3.0| 4.0| 5.0|
| 6.0| 7.0| 8.0| 9.0|10.0|
+----+----+----+----+----+
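
For reference, the same loop can be written without the mutable var by folding over the column names; this is just an equivalent sketch, assuming the df and colNames defined above:

// fold over the column names, casting each one in turn (no var needed)
val df2 = colNames.foldLeft(df) { (acc, colName) =>
  acc.withColumn(colName, col(colName).cast(DoubleType))
}
df2.show()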

3. Passing columns with : _*

The withColumn loop above is relatively inefficient, though. After asking around, it turned out that Scala can pass an array as varargs with the array: _* syntax, and DataFrame's select accepts its arguments that way, so in the end it can be written like this:

val cols = colNames.map(f => col(f).cast(DoubleType))
df.select(cols: _*).show()
+----+----+----+----+----+
|col1|col2|col3|col4|col5|
+----+----+----+----+----+
| 1.0| 2.0| 3.0| 4.0| 5.0|
| 6.0| 7.0| 8.0| 9.0|10.0|
+----+----+----+----+----+
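
If the DataFrame mixes string and numeric columns, the same select trick can cast only the string columns and pass the others through unchanged. A minimal sketch, assuming the same df:

// cast only the StringType columns, keep the other columns as they are
val mixedCols = df.schema.fields.map { f =>
  if (f.dataType == StringType) col(f.name).cast(DoubleType) else col(f.name)
}
df.select(mixedCols: _*).printSchema()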

This also makes it convenient to select a given set of columns and to cast just those columns:

val name = "col1,col3,col5"
df.select(name.split(",").map(name => col(name)): _*).show()
df.select(name.split(",").map(name => col(name).cast(DoubleType)): _*).show()
+----+----+----+
|col1|col3|col5|
+----+----+----+
|   1|   3|   5|
|   6|   8|  10|
+----+----+----+

+----+----+----+
|col1|col3|col5|
+----+----+----+
| 1.0| 3.0| 5.0|
| 6.0| 8.0|10.0|
+----+----+----+
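
And if the goal is to keep all the columns but cast only the named ones, membership in the name list can decide per column. A rough sketch, assuming the same df and the name string above:

// cast only the columns listed in name, keep the rest untouched
val wanted = name.split(",").toSet
val allCols = df.columns.map { c =>
  if (wanted.contains(c)) col(c).cast(DoubleType) else col(c)
}
df.select(allCols: _*).show()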

The full code is attached below:

package com.dkl.leanring.spark.test

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.types._
import org.apache.spark.sql.DataFrame
object DfDemo {

  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().appName("DfDemo").master("local").getOrCreate()
    import org.apache.spark.sql.types._
    val data = Array(("1", "2", "3", "4", "5"), ("6", "7", "8", "9", "10"))
    val df = spark.createDataFrame(data).toDF("col1", "col2", "col3", "col4", "col5")

    import org.apache.spark.sql.functions._
    df.select(col("col1").cast(DoubleType)).show()

    val colNames = df.columns

    var df1 = df
    for (colName <- colNames) {
      df1 = df1.withColumn(colName, col(colName).cast(DoubleType))
    }
    df1.show()

    val cols = colNames.map(f => col(f).cast(DoubleType))
    df.select(cols: _*).show()
    val name = "col1,col3,col5"
    df.select(name.split(",").map(name => col(name)): _*).show()
    df.select(name.split(",").map(name => col(name).cast(DoubleType)): _*).show()

  }
}