
I have a Spark dataframe named df as input:

+---------------+---+---+---+---+
|Main_CustomerID| A1| A2| A3| A4|
+---------------+---+---+---+---+
|            101|  1|  0|  2|  1|
|            102|  0|  3|  1|  1|
|            103|  2|  1|  0|  0|
+---------------+---+---+---+---+

I need to collect the values of A1, A2, A3, A4 into an mllib matrix, such as:

dm: org.apache.spark.mllib.linalg.Matrix =
1.0  0.0  2.0  1.0
0.0  3.0  1.0  1.0
2.0  1.0  0.0  0.0

How can I achieve this in Scala?


1 Answer


You can do it as follows. First, get all columns that should be included in the matrix:

import org.apache.spark.sql.functions._

val matrixColumns = df.columns.filter(_.startsWith("A")).map(col(_))

Then convert the dataframe to an RDD[Vector]. Since the vectors need to contain doubles, this conversion needs to be done here as well.

import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.linalg.distributed.{IndexedRow, IndexedRowMatrix}
import spark.implicits._ // needed for .as[Array[Int]]

val rdd = df.select(array(matrixColumns:_*).as("arr")).as[Array[Int]].rdd
  .zipWithIndex()
  .map{ case (arr, index) => IndexedRow(index, Vectors.dense(arr.map(_.toDouble))) }

Then convert the RDD to an IndexedRowMatrix, which can be converted, if required, to a local Matrix:

val dm = new IndexedRowMatrix(rdd).toBlockMatrix().toLocalMatrix()

For smaller matrices that can be collected to the driver there is an easier alternative:

import org.apache.spark.mllib.linalg.Matrices
import spark.implicits._ // needed for .as[Array[Int]]

val matrixColumns = df.columns.filter(_.startsWith("A")).map(col(_))

val arr = df.select(array(matrixColumns:_*).as("arr")).as[Array[Int]]
  .collect()
  .flatten
  .map(_.toDouble)

val rows = df.count().toInt
val cols = matrixColumns.length

// Matrices.dense fills the matrix column-major while arr is row-major,
// so it's necessary to reverse cols and rows here and then transpose
val dm = Matrices.dense(cols, rows, arr).transpose
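To see why the dimensions are swapped: Matrices.dense reads its values column by column (column-major), while the collected arr is laid out row by row. A minimal sketch of the index arithmetic in plain Scala (no Spark needed), using the 3x4 example data from the question:

```scala
// Row-major data as collected from the example dataframe (3 rows x 4 columns).
val arr = Array(1.0, 0.0, 2.0, 1.0,
                0.0, 3.0, 1.0, 1.0,
                2.0, 1.0, 0.0, 0.0)
val rows = 3
val cols = 4

// Interpreted as a (cols x rows) column-major matrix, entry (i, j) sits at
// index j * cols + i. Transposing swaps the indices, so entry (i, j) of the
// final (rows x cols) matrix is the pre-transpose entry (j, i), i.e.
// arr(i * cols + j) -- exactly the row-major layout of the original dataframe.
def entry(i: Int, j: Int): Double = arr(i * cols + j)

println(entry(0, 2)) // A3 for customer 101: 2.0
println(entry(1, 1)) // A2 for customer 102: 3.0
```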
  • This answer was helpful. But why .as[Array[Int]]? – PRIYA M Jul 05 '18 at 07:24
  • @PRIYAM: Happy to help :) The `.as[]` part is just to make it clearer when converting to an RDD or collecting a dataframe. The alternative is to do, for example: `df.rdd.map(_.getAs[Array[Int]](0))`. – Shaido Jul 05 '18 at 07:31