I'm much more familiar with the Scala API, so this answer uses it, but it should be straightforward to convert:
scala> val df = spark.createDataset(Seq(
| ("Q31", "P36", "Q239"),
| ("Q31", "P625", "51"),
| ("Q45", "P36", "Q597"),
| ("Q45", "P625", "123"),
| ("Q51", "P625", "22"),
| ("Q24", "P625", "56")
| )).toDF("s", "p", "o")
df: org.apache.spark.sql.DataFrame = [s: string, p: string ... 1 more field]
scala> (df.select($"s", struct($"p", $"o").as("po"))
| .groupBy("s")
| .agg(collect_list($"po").as("polist"))
| .as[(String, Array[(String, String)])]
| .flatMap(r => {
| val ps = r._2.map(_._1).toSet
| if(ps("P625") && ps("P36")) {
| r._2.map(po => (r._1, po._1, po._2))
| } else {
| None
| }
| }).toDF("s", "p", "o")
| .show())
+---+----+----+
| s| p| o|
+---+----+----+
|Q31| P36|Q239|
|Q31|P625| 51|
|Q45| P36|Q597|
|Q45|P625| 123|
+---+----+----+
For reference, your join() command above would have returned:
scala> df.filter($"p" === "P625").join(df.filter($"p" === "P36"), "s").show
+---+----+---+---+----+
| s| p| o| p| o|
+---+----+---+---+----+
|Q31|P625| 51|P36|Q239|
|Q45|P625|123|P36|Q597|
+---+----+---+---+----+
This can be worked into your final solution as well, perhaps with less code. I'm not sure which method would be more efficient, though, as that is largely data dependent.
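As a rough sketch of how a join-based variant might look: collect the subjects that have both predicates, then keep only the rows of df belonging to those subjects. The use of intersect and a left semi join is my own choice here, not something taken from your code, and the names subjects and result are purely illustrative. The local SparkSession is only there to make the snippet self-contained; in spark-shell, spark already exists.

```scala
import org.apache.spark.sql.SparkSession

// Local session for illustration only (assumption); skip this in spark-shell.
val spark = SparkSession.builder()
  .master("local[*]")
  .appName("both-predicates")
  .getOrCreate()
import spark.implicits._

// Same sample data as in the transcript above.
val df = Seq(
  ("Q31", "P36", "Q239"),
  ("Q31", "P625", "51"),
  ("Q45", "P36", "Q597"),
  ("Q45", "P625", "123"),
  ("Q51", "P625", "22"),
  ("Q24", "P625", "56")
).toDF("s", "p", "o")

// Subjects that have at least one P625 row and at least one P36 row.
// intersect returns distinct rows, so each subject appears once.
val subjects = df.filter($"p" === "P625").select("s")
  .intersect(df.filter($"p" === "P36").select("s"))

// A left semi join keeps the rows of df whose s appears in subjects,
// without adding any columns from the right-hand side.
val result = df.join(subjects, Seq("s"), "left_semi")

result.orderBy("s", "p").show()
```

On the sample data this should give the same four rows (Q31 and Q45) as the groupBy version. Like that version, it keeps every row for a matching subject, not only the P36 and P625 rows.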