
I have a table with three columns [s, p, o]. I would like to remove the rows where, for a given entry in s, the p column does not include both of the values [P625, P36]. For example:

+----+----+-----+
|   s|   p|    o|
+----+----+-----+
| Q31| P36| Q239|
| Q31|P625|   51|
| Q45| P36| Q597|
| Q45|P625|  123|
| Q51|P625|   22|
| Q24|P625|   56|
+----+----+-----+

The end result should be

+----+----+-----+
|   s|   p|    o|
+----+----+-----+
| Q31| P36| Q239|
| Q31|P625|   51|
| Q45| P36| Q597|
| Q45|P625|  123|
+----+----+-----+

Using a join operation, the above task is easy:

df.filter(df.p=='P625').join(df.filter(df.p=='P36'),'s')

But is there a more elegant way to do this?
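
For reference, the sample table above can be rebuilt with a few lines of PySpark (assuming an active SparkSession named `spark`):

from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# the example rows from the table above
df = spark.createDataFrame(
    [("Q31", "P36", "Q239"),
     ("Q31", "P625", "51"),
     ("Q45", "P36", "Q597"),
     ("Q45", "P625", "123"),
     ("Q51", "P625", "22"),
     ("Q24", "P625", "56")],
    ["s", "p", "o"])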

user1848018
  • I haven't used this technology, but you may see if it will allow you to do df.filter(df.p=='P265' || df.p=='P36') and you may find something here: https://stackoverflow.com/a/35882046/2793683 – dmoore1181 Feb 28 '19 at 20:23
  • @pault and dmoore1181 both suggested queries won't remove the last three rows of the original table in the given example. – user1848018 Feb 28 '19 at 20:33
  • It is not a duplicate; I am not conditioning on different columns of the same row. – user1848018 Feb 28 '19 at 20:34
  • @user1848018 I see what you mean now. I think you have to do a join in this case. There may be other ways, but the join is likely most efficient. – pault Feb 28 '19 at 20:36
  • That `join()` will not end up with your end result example either. I think you're likely going to have to turn `p` and `o` into a single column `struct()`, then do `.groupBy()`, `.agg()`, `.filter()`, then `.flatMap()` to get your end result example. – Travis Hegner Feb 28 '19 at 21:37

2 Answers


Forgive me, as I'm much more familiar with the Scala API, but perhaps you can easily convert it:

scala> val df = spark.createDataset(Seq(
     |      ("Q31", "P36", "Q239"),
     |      ("Q31", "P625", "51"),
     |      ("Q45", "P36", "Q597"),
     |      ("Q45", "P625", "123"),
     |      ("Q51", "P625", "22"),
     |      ("Q24", "P625", "56")
     | )).toDF("s", "p", "o")
df: org.apache.spark.sql.DataFrame = [s: string, p: string ... 1 more field]

scala> (df.select($"s", struct($"p", $"o").as("po"))
     |   .groupBy("s")
     |   .agg(collect_list($"po").as("polist"))
     |   .as[(String, Array[(String, String)])]
     |   .flatMap(r => {
     |     val ps = r._2.map(_._1).toSet
     |     if (ps("P625") && ps("P36")) {
     |       r._2.flatMap(po => Some(r._1, po._1, po._2))
     |     } else {
     |       None
     |     }
     |   }).toDF("s", "p", "o")
     |   .show())
+---+----+----+                                                                 
|  s|   p|   o|
+---+----+----+
|Q31| P36|Q239|
|Q31|P625|  51|
|Q45| P36|Q597|
|Q45|P625| 123|
+---+----+----+
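
If it helps, here is one possible PySpark rendering of the same idea (just a sketch that assumes the question's DataFrame `df`; it swaps the typed flatMap for `array_contains` and `explode`, which should behave equivalently here):

from pyspark.sql.functions import struct, collect_list, array_contains, explode, col

(df.select("s", struct("p", "o").alias("po"))
   .groupBy("s")
   .agg(collect_list("po").alias("polist"))
   # polist.p pulls the array of p values out of the array of structs
   .filter(array_contains(col("polist.p"), "P36") & array_contains(col("polist.p"), "P625"))
   # explode the surviving lists back into one row per (p, o) pair
   .select("s", explode("polist").alias("po"))
   .select("s", "po.p", "po.o")
   .show())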

For reference, your join() command above would have returned:

scala> df.filter($"p" === "P625").join(df.filter($"p" === "P36"), "s").show
+---+----+---+---+----+
|  s|   p|  o|  p|   o|
+---+----+---+---+----+
|Q31|P625| 51|P36|Q239|
|Q45|P625|123|P36|Q597|
+---+----+---+---+----+

That result can be worked into your final solution as well, perhaps with less code, but I'm not sure which method would be more efficient, as that's largely data dependent.
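
In PySpark terms, one way to work that join into the final three-column shape (again only a sketch, assuming the question's `df`; `subjects` is just an illustrative name) is to join on the subject column alone and then semi-join back:

# subjects that have both a P36 row and a P625 row
subjects = (df.filter(df.p == "P36").select("s")
              .join(df.filter(df.p == "P625").select("s"), "s"))

# leftsemi keeps only the original columns, so no duplicate p/o columns appear
df.join(subjects, "s", "leftsemi").show()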

Travis Hegner

You need a window. Collect the `p` values seen for each `s`, then keep only the rows whose subject has both P625 and P36:

from pyspark.sql import Window
from pyspark.sql.functions import collect_list, array_contains, col

winSpec = Window.partitionBy('s')

# collect all p values per subject, filter on the collected list, then drop the helper column
(df.withColumn("p_list", collect_list("p").over(winSpec))
   .filter(array_contains(col("p_list"), "P625") & array_contains(col("p_list"), "P36"))
   .drop("p_list"))
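
Against the sample data from the question, chaining `.show()` onto that expression should print only the Q31 and Q45 rows (row order may vary after the shuffle):

(df.withColumn("p_list", collect_list("p").over(winSpec))
   .filter(array_contains(col("p_list"), "P625") & array_contains(col("p_list"), "P36"))
   .drop("p_list")
   .show())

+---+----+----+
|  s|   p|   o|
+---+----+----+
|Q31| P36|Q239|
|Q31|P625|  51|
|Q45| P36|Q597|
|Q45|P625| 123|
+---+----+----+
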
ayplam