How can I parallelize a loop in Spark so that the processing is not sequential but parallel? To take an example, I have the following data contained in a CSV file (called 'bill_item.csv'):
|---------+---------|
| bill_id | item_id |
|---------+---------|
| ABC     | 1       |
| ABC     | 2       |
| DEF     | 1       |
| DEF     | 2       |
| DEF     | 3       |
| GHI     | 1       |
|---------+---------|
I have to get the output as follows:
|--------+--------+--------------|
| item_1 | item_2 | Num_of_bills |
|--------+--------+--------------|
| 1      | 2      | 2            |
| 2      | 3      | 1            |
| 1      | 3      | 1            |
|--------+--------+--------------|
We see that items 1 and 2 appear together under 2 bills, 'ABC' and 'DEF', hence 'Num_of_bills' for items 1 and 2 is 2. Similarly, items 2 and 3 appear together only under bill 'DEF', so their 'Num_of_bills' is 1, and so on.
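To make the pairing logic explicit, here is a tiny plain-Python illustration on the sample data (the dictionary is just the table above, re-typed by hand):

from itertools import combinations
from collections import Counter

# items per bill, copied from the sample table above
bills = {'ABC': [1, 2], 'DEF': [1, 2, 3], 'GHI': [1]}

pair_counts = Counter()
for items in bills.values():
    # every unordered pair of items that appears together on one bill
    for pair in combinations(sorted(items), 2):
        pair_counts[pair] += 1

print(pair_counts)  # Counter({(1, 2): 2, (1, 3): 1, (2, 3): 1})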
I am using Spark to process the CSV file 'bill_item.csv', and I have tried the following approaches:
Approach 1:
from pyspark.sql.types import StructType, StructField, IntegerType, StringType

# define the schema for the data
bi_schema = StructType([
    StructField("bill_id", StringType(), True),
    StructField("item_id", IntegerType(), True)
])

bi_df = spark.read.schema(bi_schema).csv('bill_item.csv')

# find the list of all items in sorted order
item_list = bi_df.select("item_id").distinct().orderBy("item_id").collect()
item_list_len = len(item_list)

# for each pair of items, e.g. (1,2), (1,3), (1,4), (1,5), (2,3), (2,4), (2,5), ..., (4,5)
i = 0
while i < item_list_len - 1:
    # find the list of all bill IDs that contain the first item of the pair
    bill_id_list1 = bi_df.filter(bi_df.item_id == item_list[i].item_id).select("bill_id").collect()
    j = i + 1
    while j < item_list_len:
        # find the list of all bill IDs that contain the second item of the pair
        bill_id_list2 = bi_df.filter(bi_df.item_id == item_list[j].item_id).select("bill_id").collect()
        # find the bill IDs common to both lists, then count them
        common_elements = set(bill_id_list1).intersection(bill_id_list2)
        num_bills = len(common_elements)
        if num_bills > 0:
            print(item_list[i].item_id, item_list[j].item_id, num_bills)
        j += 1
    i += 1
However, this approach is not efficient, given that in real life we have millions of records, and it has the following issues:
- There may not be enough memory on the driver to load the list of all items or bills.
- It may take too long to get the results because the execution is sequential (thanks to the nested loops). (I ran the above algorithm with ~200,000 records and it took more than 4 hours to produce the desired result.)
Approach 2:
I further optimized this by partitioning the data by "item_id", using the following block of code:
bi_df = spark.read.schema(bi_schema).csv('bill_item.csv')
outputPath = '/path/to/save'
# write one directory per item_id value
bi_df.write.partitionBy("item_id").csv(outputPath)
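As I understand it, the gain comes from partition discovery: when the partitioned output is read back, 'item_id' is recovered from the directory names, so a filter on 'item_id' only scans the matching 'item_id=<value>' directory. A minimal read-back sketch (the path is just a placeholder, and I believe the supplied schema may include the partition column):

# read the partitioned output back; 'item_id' comes from the directory names
part_df = spark.read.schema(bi_schema).csv(outputPath)
# this filter should only scan the files under .../item_id=1/
part_df.filter(part_df.item_id == 1).select("bill_id").show()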
After partitioning, I executed the same algorithm as in "Approach 1", and for 200,000 records it still takes 1.03 hours (a significant improvement over the 4 hours of 'Approach 1') to get the final output.
The remaining bottleneck is the sequential nested loop together with the 'collect()' calls: every pair of items triggers its own filter()/collect() round trip, so the number of Spark jobs grows quadratically with the number of distinct items. So my questions are:
- Is there a way to parallelize the for loop?
- Or is there any other efficient way, for example a join-based approach like the sketch below?
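For what it's worth, I have been wondering whether the loops could be replaced entirely by a DataFrame self-join, so that Spark does the pairing and counting itself. Something along these lines (a rough sketch, not tested at scale):

from pyspark.sql import functions as F

# pair up items that share a bill via a self-join, then count distinct bills per pair;
# the 'a.item_id < b.item_id' condition keeps each unordered pair only once
pairs = (
    bi_df.alias("a")
         .join(bi_df.alias("b"), on="bill_id")
         .where(F.col("a.item_id") < F.col("b.item_id"))
         .groupBy(F.col("a.item_id").alias("item_1"),
                  F.col("b.item_id").alias("item_2"))
         .agg(F.countDistinct("bill_id").alias("Num_of_bills"))
)
pairs.show()

Would this be the right direction, or is there a better pattern for this kind of co-occurrence counting?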