
We have Unity Catalog configured in our Databricks environment. We have some functions that connect to the tables using spark.sql("the sql code") and retrieve data, and we want to write test cases for those functions. We were able to mock the query by supplying a predefined result, and we create a Spark session using the code below:

import unittest
from unittest.mock import patch

from pyspark.sql import SparkSession


class TestMyFunction(unittest.TestCase):
    def setUp(self):
        # Note: `global spark` creates a global named `spark` in this test module,
        # not in the module that defines the functions under test.
        global spark
        self.spark = SparkSession.builder.master("local").appName("test").getOrCreate()
        spark = self.spark
        self.get_breakpoint_and_lob_id_df = self.spark.createDataFrame([(1, 2)], ["LOB_ID", "BREAKPOINT_ID"])

    def tearDown(self):
        patch.stopall()

The test cases execute successfully when run in the Databricks environment, but when we deploy the code and run the unit tests on a Jenkins server, they fail with an error saying that spark is not defined. This is one of the test cases we have written:


def test_final_flag_generate(self):
    self.mock_sql.side_effect = [self.final_flag_generate_df]
    catalog_name = "test_catalog"
    val_dt = "2022-01-01"
    iteration = 1

    result_flag, result_df = final_flag_generate(self.final_flag_generate_df, catalog_name, val_dt, iteration)

    expected_flag = 1
    expected_columns = self.final_flag_generate_df.columns

    self.assertEqual(expected_flag, result_flag)
    self.assertListEqual(expected_columns, result_df.columns)
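The test above references self.mock_sql and self.final_flag_generate_df, which are not shown in setUp. A minimal sketch of how they could be wired is below; the module name my_module and the sample data are assumptions for illustration, not part of our actual code:

def setUp(self):
    self.spark = SparkSession.builder.master("local").appName("test").getOrCreate()
    # Sample data with the columns final_flag_generate reads (MAPPING_ID, balance_flag).
    self.final_flag_generate_df = self.spark.createDataFrame(
        [(10, 1), (11, 1)], ["MAPPING_ID", "balance_flag"]
    )
    # Replace the module-level `spark` used inside final_flag_generate with a mock so
    # its spark.sql(...) call can be stubbed; create=True lets the patch succeed even
    # when `spark` is not defined in that module outside Databricks.
    mock_spark = patch("my_module.spark", create=True).start()
    self.mock_sql = mock_spark.sql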

This is the original function:

from pyspark.sql.functions import col

# `spark` is expected to exist as a global SparkSession, as it does implicitly in a Databricks notebook.
def final_flag_generate(joined_sql, catalog_name, val_dt, iteration):
    mapping_id_list = []
    for row in joined_sql.collect():
        mapping_id = row.MAPPING_ID
        mapping_id_list.append(mapping_id)
    id_tuple = tuple(mapping_id_list)
    print(id_tuple)
    final_calculation_df = spark.sql(f"""select * from {catalog_name}.db1.tbl1 where mapping_id in {id_tuple} and val_dt='{val_dt}' and iteration={iteration} """)

    count_ones = final_calculation_df.filter(col("balance_flag") == 1).count()
    print(count_ones)
    total_count = final_calculation_df.count()
    print(total_count)

    if count_ones == total_count:
        final_result_flag = 1
    else:
        final_result_flag = 0
    return final_result_flag, final_calculation_df

How do we write test cases when spark.sql is called inside the function?

1 Answer

This simple base class has been pretty handy in various projects I've been on:

import abc

from pyspark.sql import SparkSession


# BaseTest is a project-specific base class that ultimately extends unittest.TestCase.
class BaseSparkTest(BaseTest, abc.ABC):
    spark: SparkSession = None

    def setUp(self) -> None:
        # Create the session once and keep it on the class, so it is shared by
        # all tests and can be stopped in tearDownClass.
        if not BaseSparkTest.spark:
            BaseSparkTest.spark = (
                SparkSession.builder.master("local")
                .config("spark.sql.shuffle.partitions", "5")
                .getOrCreate()
            )

    @classmethod
    def tearDownClass(cls) -> None:
        if BaseSparkTest.spark:
            BaseSparkTest.spark.stop()
            BaseSparkTest.spark = None

Whenever possible, I use the same session for all the tests in the class, making sure not to leak data or information between tests. When that's not possible, you have to recycle the session per test.
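For example, a concrete test class just inherits from the base class and uses self.spark; the class name and data below are only illustrative:

class MyTransformTest(BaseSparkTest):
    def test_row_count(self):
        # Build a tiny DataFrame with the shared session and assert on it.
        # assertEqual is available because BaseTest ultimately extends unittest.TestCase.
        df = self.spark.createDataFrame([(1, "a"), (2, "b")], ["id", "val"])
        self.assertEqual(df.count(), 2)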

shay__
  • Thanks for the reply. I use similar logic to what is mentioned above. It works fine for functions where the DataFrame is passed in as an argument and the output is returned. However, when spark is called inside the original function, e.g. spark.read.csv or spark.sql, it fails saying spark is not defined. Any solution for that? – Saravana Kumar Jul 17 '23 at 07:19
  • Sounds more like a generic Python issue, rather than a Spark-specific issue. – shay__ Jul 17 '23 at 08:58
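A minimal sketch of what that generic Python fix can look like, assuming the function under test lives in a hypothetical module called my_module and only references spark as a module-level name: inject a session (or a mock) into that module before calling the function. For calls such as spark.read.csv, a real local session works; for spark.sql against Unity Catalog tables you would still stub the call with a prepared DataFrame, as in the mock shown earlier.

import unittest
from unittest.mock import patch

from pyspark.sql import SparkSession

import my_module  # hypothetical module that defines final_flag_generate


class TestWithInjectedSpark(unittest.TestCase):
    def setUp(self):
        self.spark = SparkSession.builder.master("local").getOrCreate()
        # Make the module under test see a `spark` attribute, even though it is only
        # defined implicitly in a Databricks notebook; create=True allows patching a
        # name that does not exist in the module.
        self.spark_patcher = patch.object(my_module, "spark", self.spark, create=True)
        self.spark_patcher.start()

    def tearDown(self):
        self.spark_patcher.stop()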