
I know I can split a Dataset randomly with the randomSplit method:

val splittedData: Array[Dataset[Row]] = 
        preparedData.randomSplit(Array(0.5, 0.3, 0.2))

Is there some 'nonRandomSplit' method that splits the data into consecutive parts instead?

I'm using Apache Spark 2.0.1. Thanks in advance.

UPD: the order of the data matters. I'm going to train my model on rows with smaller IDs and test it on rows with larger IDs, so I want to split the data into consecutive parts without shuffling.

For example:

my dataset = (0,1,2,3,4,5,6,7,8,9)
desired split ratios = (0.8, 0.2)
resulting split = (0,1,2,3,4,5,6,7), (8,9)

The only solution I can think of is to use count and limit, but there is probably a better one.
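The count-based idea can be made concrete by turning the split ratios into row-index boundaries. Below is a minimal plain-Scala sketch; `ConsecutiveBounds` and `bounds` are hypothetical names, not part of Spark. It normalizes the weights the way randomSplit normalizes its weights and returns consecutive [start, end) ranges, so for 10 rows and (0.8, 0.2) it yields (0, 8) and (8, 10).

```scala
object ConsecutiveBounds {
  /** Hypothetical helper (not a Spark API): converts split weights into
    * consecutive [start, end) index ranges over n rows.
    * Weights are normalized to sum to 1, as randomSplit does. */
  def bounds(n: Long, weights: Seq[Double]): Seq[(Long, Long)] = {
    require(weights.nonEmpty && weights.forall(_ >= 0), "weights must be non-empty and non-negative")
    val total = weights.sum
    require(total > 0, "weights must not all be zero")
    // Cumulative fractions, e.g. (0.8, 0.2) -> (0.8, 1.0)
    val cumulative = weights.scanLeft(0.0)(_ + _).tail.map(_ / total)
    // Round every cut except the last to a row index; the last cut is always n
    val cuts = 0L +: cumulative.init.map(f => math.round(f * n)) :+ n
    cuts.zip(cuts.tail)
  }

  def main(args: Array[String]): Unit = {
    // Apply the ranges to a local collection mirroring the example above
    val data = (0 until 10).toVector
    val parts = bounds(data.size, Seq(0.8, 0.2)).map { case (start, end) =>
      data.slice(start.toInt, end.toInt)
    }
    println(parts) // List(Vector(0, 1, 2, 3, 4, 5, 6, 7), Vector(8, 9))
  }
}
```

With Spark, the same ranges can be compared against the indices produced by `zipWithIndex` instead of slicing a local collection.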

Anton
  • Please elaborate on your point. Using randomSplit to make a non-random split? I don't actually get that. – eliasah Dec 02 '16 at 14:59
  • You can give `randomSplit` a seed value to ensure you'll get the same results every time, but I'm not sure that this is what you mean – David Dec 02 '16 at 15:00
  • @eliasah Thanks for the feedback, I've updated my question. I'm actually looking for an efficient way to partition a dataset by given ratios; please see the example. – Anton Dec 02 '16 at 15:21
  • @David Thanks for the feedback, I've updated my question. I'm actually looking for an efficient way to partition a dataset by given ratios; please see the example. – Anton Dec 02 '16 at 15:21
  • Do you know the cutoffs between small and large IDs? If so, you could filter. If not, you could estimate the percentiles and then filter based on the cutoffs. – David Dec 02 '16 at 15:55
  • @David Thank you, David, I've already done this :) Please see the answer; I'm not sure it's the most efficient or elegant solution, though. Any ideas for a better or shorter one? – Anton Dec 02 '16 at 17:28
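David's suggestion (estimate a percentile of the ID column, then filter on it) can be sketched with `DataFrameStatFunctions.approxQuantile`, available since Spark 2.0. This assumes a numeric ID column; `QuantileSplit`, `splitByIdQuantile` and the default relative error are illustrative choices, not part of Spark. The quantile is approximate unless `relativeError` is 0 (exact but potentially expensive), and every row sharing the cutoff ID lands on the test side, so the split sizes only approximate the requested ratio.

```scala
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.functions.col

object QuantileSplit {
  /** Hypothetical helper: splits df into (rows with id < cutoff, rows with id >= cutoff),
    * where cutoff is the approximate trainRatio-quantile of the numeric column idCol. */
  def splitByIdQuantile(df: DataFrame, idCol: String, trainRatio: Double,
                        relativeError: Double = 0.001): (DataFrame, DataFrame) = {
    require(trainRatio > 0 && trainRatio < 1, "trainRatio must be in (0, 1)")
    // approxQuantile returns one value per requested probability; we ask for one
    val cutoff = df.stat.approxQuantile(idCol, Array(trainRatio), relativeError).head
    // Plain filters on ID values: no row numbering or partition order involved
    (df.filter(col(idCol) < cutoff), df.filter(col(idCol) >= cutoff))
  }
}
```

Unlike the `zipWithIndex` approach, this needs no extra pass to number rows, and the result depends only on ID values rather than on partition order.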

1 Answer


This is the solution I've implemented: Dataset -> RDD -> Dataset.

I'm not sure whether it is the most efficient way to do it, so I'll gladly accept a better solution.

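// zipWithIndex numbers rows in the RDD's current partition order, so allData
// should already be sorted by ID (e.g. allData.orderBy("id")).
val count = allData.count()
val trainRatio = 0.6
val trainSize = math.round(count * trainRatio) // Long, like the zipWithIndex indices
val dataSchema = allData.schema

// Zip with indices once and cache, so both filters reuse the same numbering
val indexedRdd = allData.rdd.zipWithIndex().cache()

// Keep rows with index < trainSize.
// Could possibly have used .limit(n) here
val trainingRdd =
  indexedRdd
    .filter { case (_, index) => index < trainSize }
    .map { case (row, _) => row }

// Spark 2.0's Dataset has no "skip the first n rows" counterpart to .limit(), so filter on the index
val testRdd =
  indexedRdd
    .filter { case (_, index) => index >= trainSize }
    .map { case (row, _) => row }

// MySession is the application's SparkSession
val training = MySession.createDataFrame(trainingRdd, dataSchema)
val test = MySession.createDataFrame(testRdd, dataSchema)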

Anton
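The answer's approach generalizes to any number of weights, matching the shape of randomSplit's `Array(0.5, 0.3, 0.2)` argument. This is a sketch, not a Spark API: `ConsecutiveSplit.consecutiveSplit` is a hypothetical helper, and it assumes the DataFrame is already in the desired order (for example after `orderBy("id")`), because `zipWithIndex` numbers rows in the RDD's partition order.

```scala
import org.apache.spark.sql.DataFrame

object ConsecutiveSplit {
  /** Hypothetical helper: splits df into consecutive pieces, in its current row order,
    * with sizes proportional to weights (normalized, like randomSplit's weights). */
  def consecutiveSplit(df: DataFrame, weights: Array[Double]): Array[DataFrame] = {
    require(weights.nonEmpty && weights.forall(_ >= 0) && weights.sum > 0,
      "weights must be non-negative and not all zero")
    val spark = df.sparkSession
    // Number every row once and cache, so each piece reuses the same indices.
    // The cache is left in place because the returned DataFrames depend on it.
    val indexed = df.rdd.zipWithIndex().cache()
    val n = indexed.count()
    // Cut points, e.g. weights (0.8, 0.2) over 10 rows -> 0, 8, 10
    val total = weights.sum
    val cuts = 0L +: weights.scanLeft(0.0)(_ + _).tail.init.map(w => math.round(w / total * n)) :+ n
    cuts.zip(cuts.tail).map { case (start, end) =>
      val rows = indexed
        .filter { case (_, i) => i >= start && i < end }
        .map { case (row, _) => row }
      spark.createDataFrame(rows, df.schema)
    }
  }
}
```

For the question's example, `consecutiveSplit(allData.orderBy("id"), Array(0.8, 0.2))` would return the first 80% of rows by ID followed by the remaining 20%.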