As the new array functions were only introduced in Spark 2.4, on earlier versions you have to resort to a user-defined function (UDF).
User-defined functions in Java are Java objects with an apply method that can be used like built-in functions in dataframe transformations. To create such an object, you first create a UDFx object, where x is the number of arguments of your UDF. Then you create your UDF from this UDFx object, either by registering it with the method sparkSession.sqlContext().udf().register() (the only method available before Spark 2.3) or by creating it with the function udf (for Spark 2.3 and greater) as described in this answer.
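For the registration route, a minimal sketch might look like this (assuming a SparkSession named spark and a UDF1 instance like the flattenArray one defined in the code below; the name "flatten_array" is an arbitrary choice):

```java
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.types.DataTypes;

public class RegistrationSketch {
    public static void register(SparkSession spark, UDF1<?, ?> flattenArray) {
        // register the UDF under a name so it can be used with callUDF or in SQL
        spark.sqlContext().udf().register(
                "flatten_array",                                 // hypothetical name
                flattenArray,
                DataTypes.createArrayType(DataTypes.StringType)); // UDF return type
    }
}
```

Once registered, the UDF can be invoked by name, e.g. with callUDF("flatten_array", col("col")) or inside a spark.sql(...) query.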
And finally you use it with the function callUDF or directly using apply. So the complete code for Spark 2.3 and above is as follows:
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.api.java.UDF1;
import org.apache.spark.sql.expressions.UserDefinedFunction;
import org.apache.spark.sql.types.DataTypes;
import scala.collection.Seq;
import java.util.stream.Collectors;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.udf;
import static scala.collection.JavaConverters.asScalaBuffer;
import static scala.collection.JavaConverters.seqAsJavaList;
public class Flattener {
    public static Dataset<Row> flattenArray(Dataset<Row> input, String columnName) {
        // define what your user-defined function does
        UDF1<Seq<Seq<String>>, Seq<String>> flattenArray = new UDF1<Seq<Seq<String>>, Seq<String>>() {
            @Override
            public Seq<String> call(Seq<Seq<String>> s) {
                return asScalaBuffer(
                        seqAsJavaList(s)
                                .stream()
                                .flatMap(x -> seqAsJavaList(x).stream())
                                .collect(Collectors.toList())
                ).toSeq();
            }
        };

        // convert it to a user-defined function
        UserDefinedFunction flatten_array = udf(
                flattenArray,
                DataTypes.createArrayType(DataTypes.StringType) // output type of your UDF
        );

        // apply your user-defined function
        return input.withColumn(columnName, flatten_array.apply(col(columnName)));
    }
}
Note: when working with sequences in a Java UDF, you need to use the Scala Seq type instead of the Java List type as sequence input. To convert from one to the other, look at the methods of the Scala class JavaConverters.
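Stripped of those Scala conversions, the core of the call method above is just a Java stream flatMap. A self-contained sketch on plain Java lists, to make the logic easy to verify in isolation:

```java
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;

public class FlattenDemo {
    // collapse a list of lists into a single list, preserving element order
    public static List<String> flatten(List<List<String>> nested) {
        return nested.stream()
                .flatMap(List::stream)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<List<String>> nested = Arrays.asList(
                Arrays.asList("a", "b"),
                Arrays.asList("c"));
        System.out.println(flatten(nested)); // prints [a, b, c]
    }
}
```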
You can then call your flattenArray
method on your dataframe:
Flattener.flattenArray(dataframe, "name_of_column_you_want_to_flatten");
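For the record, if you can upgrade to Spark 2.4 or later, the built-in flatten function does the same thing without any UDF:

```java
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import static org.apache.spark.sql.functions.col;
import static org.apache.spark.sql.functions.flatten;

public class BuiltInFlattener {
    // Spark 2.4+ only: flatten collapses an array of arrays into a single array
    public static Dataset<Row> flattenArray(Dataset<Row> input, String columnName) {
        return input.withColumn(columnName, flatten(col(columnName)));
    }
}
```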