
I need an efficient way to list and drop unary columns in a Spark DataFrame (I use the PySpark API). I define a unary column as one which has at most one distinct value, and for the purpose of the definition, I count null as a value as well. That means that a column with one distinct non-null value in some rows and null in other rows is not a unary column.
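To make the definition concrete, here is a minimal pure-Python sketch (no Spark involved; `is_unary` is a hypothetical helper). It treats a column as a plain list of values, with `None` standing in for null:

```python
def is_unary(values):
    """Return True if `values` holds at most one distinct value, counting None as a value."""
    # A Python set keeps None as an ordinary element, which matches the definition above.
    return len(set(values)) <= 1


print(is_unary([]))            # → True   (no rows at all)
print(is_unary([None, None]))  # → True   (all null)
print(is_unary([7, 7, 7]))     # → True   (one non-null value)
print(is_unary([7, None, 7]))  # → False  (one value plus null is not unary)
```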

Based on the answers to this question I managed to write an efficient way to obtain a list of null columns (which are a subset of my unary columns) and drop them as follows:

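# summary("count") returns one row of non-null counts, as strings,
# plus a 'summary' label column (its value 'count' never matches '0')
counts = df.summary("count").collect()[0].asDict()
null_cols = [c for c in counts.keys() if counts[c] == '0']
df2 = df.drop(*null_cols)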

Based on my very limited understanding of Spark's inner workings, this is fast because the summary method computes the counts for all columns in a single pass over the DataFrame (I have roughly 300 columns in my initial DataFrame). Unfortunately, I cannot find a similar way to deal with the second type of unary columns: ones which have no null values but hold the same value in every row, as if created with lit(something).

What I currently have is this (using the df2 I obtain from the code snippet above):

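from pyspark.sql import functions as F

# Approximate number of distinct non-null values per column, in a single pass
prox_counts = (
    df2.agg(*(F.approx_count_distinct(F.col(c)).alias(c) for c in df2.columns))
    .collect()[0]
    .asDict()
)
# Cheap pre-selection of candidates (approx_count_distinct ignores nulls)
poss_unarcols = [k for k, v in prox_counts.items() if v < 3]
# Exact check, but one Spark job per candidate column
unar_cols = [c for c in poss_unarcols if df2.select(c).distinct().count() < 2]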

Essentially, I first find columns which could be unary in a fast but approximate way and then look at the "candidates" in more detail and more slowly.

What I don't like about it is that a) even with the approximate pre-selection it is still fairly slow, taking over a minute to run even though at this point I only have roughly 70 columns (and about 6 million rows), and b) I use approx_count_distinct with the magic constant 3 (approx_count_distinct does not count null, hence 3 instead of 2). Since I'm not sure how approx_count_distinct works internally, I am a little worried that 3 is not a good threshold: the function might estimate the number of distinct (non-null) values as, say, 5 when it is really 1, in which case a higher constant would be needed to guarantee that nothing is missing from the candidate list poss_unarcols.

Is there a smarter way to do this, ideally so that I don't even have to drop the null columns separately and can do it all in one fell swoop (although that step is actually quite fast, so it is not a big issue)?

Masoud Rahimi
Trademark

4 Answers


I suggest that you have a look at the following function

pyspark.sql.functions.collect_set(col)

https://spark.apache.org/docs/latest/api/python/pyspark.sql.html?highlight=dataframe

It returns all the values in col with duplicate elements eliminated. You can then check whether the length of the result equals one. I am not sure about the performance, but I think it will definitely beat distinct().count(). Let's have a look on Monday :)
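One caveat worth checking before relying on this: like most Spark aggregate functions, collect_set skips nulls, so a column holding one non-null value plus some nulls also yields a one-element set, even though it is not unary under the question's definition. The pure-Python sketch below (no Spark; `collect_set_like` and `unary_via_collect_set` are hypothetical helpers) mimics that behaviour and adds the null check it needs:

```python
def collect_set_like(values):
    """Mimic Spark's collect_set: distinct values with nulls (None) dropped."""
    return {v for v in values if v is not None}


def unary_via_collect_set(values):
    """Unary test built on a null-dropping set plus a separate null count."""
    distinct = collect_set_like(values)
    null_count = sum(v is None for v in values)
    if not distinct:
        # Only nulls (or no rows at all): unary.
        return True
    # A single non-null value is unary only if no nulls appear alongside it.
    return len(distinct) == 1 and null_count == 0


column = ["a", None, "a"]
print(len(collect_set_like(column)) == 1)  # → True   (the naive length check passes)
print(unary_via_collect_set(column))       # → False  (the null makes it non-unary)
```

In Spark terms, the null count would come from comparing F.count(c), which counts only non-null values, with the total row count.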

bazinac

You can call df.na.fill("some non-existing value").summary() and then drop the relevant columns from the original DataFrame.
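One way to read this suggestion (an interpretation, not necessarily what was meant): once nulls are replaced by a sentinel that never occurs in the data, the min and max rows that summary() reports are enough, because a column is unary exactly when its minimum equals its maximum. A pure-Python sketch of that test, with a hypothetical `SENTINEL` and helper name:

```python
SENTINEL = "__missing__"  # hypothetical placeholder, assumed absent from the real data


def unary_via_fill_min_max(values, sentinel=SENTINEL):
    """Fill nulls with a sentinel, then call the column unary iff min == max."""
    filled = [sentinel if v is None else v for v in values]
    if not filled:
        return True  # no rows: trivially unary
    return min(filled) == max(filled)


print(unary_via_fill_min_max(["x", "x"]))    # → True
print(unary_via_fill_min_max(["x", None]))   # → False  ("x" vs the sentinel)
print(unary_via_fill_min_max([None, None]))  # → True   (only the sentinel)
```

The sentinel has to be comparable with the column's values (here a string for a string column), which is why a separate fill per data type would be needed on a DataFrame with mixed column types.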

Arnon Rotem-Gal-Oz
  • The problem with this is that the columns have various data types, so I would have to call this several times. Also, I realized that my definition of unary means that once I remove the fully null columns, I only need to check the columns which have no null values in them; all others have at least one row equal to null and one row equal to something else, so they are not unary. – Trademark Apr 16 '19 at 08:29
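The reasoning in this comment can be sketched in plain Python (no Spark; `classify_columns` is a hypothetical helper, and `non_null_counts` stands in for the per-column counts that summary("count") returns, already converted to int):

```python
def classify_columns(row_count, non_null_counts):
    """Split columns into all-null, null-free and mixed groups.

    Only the null-free group can still contain unary columns: a mixed
    column has at least one null and at least one non-null value.
    """
    null_cols, full_cols, mixed_cols = [], [], []
    for name, non_null in non_null_counts.items():
        if non_null == 0:
            null_cols.append(name)
        elif non_null == row_count:
            full_cols.append(name)
        else:
            mixed_cols.append(name)
    return null_cols, full_cols, mixed_cols


# Hypothetical counts for a 6-row DataFrame
print(classify_columns(6, {"a": 0, "b": 6, "c": 4}))
# → (['a'], ['b'], ['c'])
```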

So far the best solution I found is this (it is faster than the other proposed answers, although not ideal, see below):

from pyspark.sql import functions as F

rows = df.count()
# summary("count") returns the non-null count of each column as a string,
# plus a 'summary' label column that we discard.
# Note: summary() only reports numeric and string columns, so columns of
# other types are left unclassified by this code.
nullcounts = df.summary("count").collect()[0].asDict()
del nullcounts['summary']
# convert to the number of null values per column
nullcounts = {key: rows - int(value) for key, value in nullcounts.items()}

# a list for columns with just null values
null_cols = []
# a list for columns with no null values
full_cols = []

for key, value in nullcounts.items():
    if value == rows:
        null_cols.append(key)
    elif value == 0:
        full_cols.append(key)

df = df.drop(*null_cols)

# only columns in full_cols can be unary;
# all other remaining columns have at least 1 null and 1 non-null value
unar_cols = []
if full_cols:  # skip the aggregation entirely when there are no candidates
    unarcounts = (df.agg(*(F.countDistinct(F.col(c)).alias(c) for c in full_cols))
                    .collect()[0]
                    .asDict()
                  )
    unar_cols = [key for key, count in unarcounts.items() if count == 1]

df = df.drop(*unar_cols)

This works reasonably fast, mostly because I don't have many "full columns", i.e. columns which contain no null rows. The fast summary("count") method classifies as many columns as possible, and only the full columns need a pass over all of their rows.

Going through all rows of a column seems incredibly wasteful to me, since once two distinct values are found, I don't really care what's in the rest of the column. I don't think this can be solved in PySpark, though (but I am a beginner): it seems to require a UDF, and PySpark UDFs are so slow that it is not likely to be faster than using countDistinct(). Still, as long as a DataFrame has many columns with no null rows, this method will be pretty slow (and I am not sure how much one can trust approx_count_distinct() to differentiate between one and two distinct values in a column).
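The early-exit idea can at least be written down outside Spark. The sketch below (plain Python, row at a time, hypothetical `find_unary_columns` helper; rows are dicts with None for null) remembers the first value seen in each column, rules a column out at its first mismatch, and stops reading rows once every column has been ruled out. It only illustrates the algorithm; it does not show how to make Spark stop early.

```python
def find_unary_columns(rows, columns):
    """Return the columns whose values (None included) are all identical."""
    undecided = set(columns)
    first = {}
    for row in rows:
        for col in list(undecided):
            value = row[col]
            if col not in first:
                first[col] = value
            elif value != first[col]:
                undecided.discard(col)  # two distinct values: not unary
        if not undecided:
            break  # every column is already ruled out; stop reading rows
    return [c for c in columns if c in undecided]


rows = [
    {"a": 1, "b": None, "c": "x"},
    {"a": 1, "b": None, "c": "y"},
    {"a": 1, "b": 2, "c": "x"},
]
print(find_unary_columns(rows, ["a", "b", "c"]))  # → ['a']
```

In principle, df.toLocalIterator() could feed this function, since its Row objects support row[name] lookups, but streaming 6 million rows to the driver is likely to be much slower than a single aggregation.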

As far as I can tell, it beats the collect_set() approach, and, as I realized, filling the null values is not actually necessary (see the comments in the code).

Trademark

I tried your solution, and it was too slow in my situation, so I simply grabbed the first row of the DataFrame and, for each column, counted the rows whose value differs from it; a column with no such rows is unary. This turned out to be far more performant. I'm sure there's a better way, but I don't know what it is!

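from pyspark.sql import functions as sqlf

first_row = df.limit(1).collect()[0]

# For each column, count the rows whose value differs from the first row's.
# eqNullSafe treats null as an ordinary value, so a column mixing one value
# with nulls is not mistaken for a unary column (a plain != yields null
# there, and count() skips nulls).
drop_cols = [
    key for key, value in df.select(
        [
            sqlf.count(
                sqlf.when(~sqlf.col(column).eqNullSafe(first_row[column]), column)
            ).alias(column)
            for column in df.columns
        ]
    ).collect()[0].asDict().items()
    if value == 0
]

df = df.drop(*drop_cols)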
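Whether the comparison in the code above is null-safe matters. In SQL, a plain x != NULL evaluates to null rather than true, and count() skips nulls, so rows where either side is null silently drop out of the mismatch count. Spark's Column.eqNullSafe (SQL <=>) instead treats two nulls as equal and a null versus a value as different. This plain-Python sketch (hypothetical helpers; None plays the role of null) emulates both comparisons on one column:

```python
def sql_not_equal(a, b):
    """SQL-style a != b: unknown (None) if either side is null."""
    if a is None or b is None:
        return None
    return a != b


def null_safe_not_equal(a, b):
    """Negation of null-safe equality: null equals null, null differs from a value."""
    if a is None and b is None:
        return False
    if a is None or b is None:
        return True
    return a != b


def count_mismatches(values, compare):
    """Mimic count(when(cond, ...)): count rows where cond is exactly True."""
    first = values[0]
    return sum(1 for v in values if compare(v, first) is True)


column = ["x", None, "x"]
print(count_mismatches(column, sql_not_equal))        # → 0  (column would be dropped)
print(count_mismatches(column, null_safe_not_equal))  # → 1  (column is kept)
```

The same problem appears when the first row itself holds null: with a plain !=, every comparison is null, so the column is dropped whatever its other values are.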
James