Unit Testing Databricks Notebooks – Part 3

In the last two posts, we covered the overall approach to unit testing notebooks, as well as a notebook structure that lets the same query source its data from different databases. But how does the data get into those databases?

As I explained in the first post, we are using the medallion architecture. On normal runs, the query in the example would source its data from Silver zone delta tables, but during test runs it sources from a database called “unit_test_db”.

We need to create some sort of process that takes seed data and populates the tables in this database, and this is where databricks-connect and a Python project come into play.

By using databricks-connect we can create a hybrid approach to unit testing. There are other libraries that can execute notebooks from a Python project, but they don’t offer the flexibility we’re looking for, so we wrote our own utilities.

The project structure is as follows:

UnitTests
    |___configs
        |.databricks-connect
    |___unittests
        |test_simple.py
    |___utils
        |compareutils.py
        |configutils.py
        |databricksutils.py
        |sparkutils.py
    |main.py
    |requirements.txt

Each of the utils is designed to streamline the creation of unit tests that run against notebooks.

  • compareutils.py
    • A utility that takes two DataFrames with the same schema and compares them to determine leftNotInRight, rightNotInLeft, sameInBoth, and changes (a sketch follows this list)
  • configutils.py
    • A utility that uses the config you have in your .databricks-connect file when calling the Databricks API
  • databricksutils.py
    • A utility that calls the notebook, checks state, and returns the output
  • sparkutils.py
    • A utility that creates databases and tables in Spark SQL
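
To make the comparison idea concrete, here is a minimal sketch of what a compareutils.diffdataframe class could look like, given how it is used later in the test (two DataFrames and a key column in, four counts out). The internals shown are an assumption, not the project’s actual implementation.

# Hypothetical sketch of compareutils.py; the real utility may differ.
from pyspark.sql import DataFrame


class diffdataframe:
    """Compares two DataFrames with the same schema on a key column."""

    def __init__(self, left: DataFrame, right: DataFrame, keyColumn: str):
        # Keys present on one side only
        self.leftNotInRight = left.join(right, keyColumn, "left_anti").count()
        self.rightNotInLeft = right.join(left, keyColumn, "left_anti").count()
        # Rows identical in every column
        self.sameInBoth = left.intersect(right).count()
        # Keys on both sides whose non-key values differ
        matchedKeys = left.select(keyColumn).intersect(right.select(keyColumn)).count()
        self.changes = matchedKeys - self.sameInBoth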

If we look at both the seed data and the assertion data we created, we can begin to see how this process works.

Seed Data

seed_BusinessUnitMaster.csv
Subsidiary
570
999
120

seed_UserDefinedCodes.csv
userDefinedCodes,productCode,userDefinedCode,Description
18,00,999,Subsidiary 1
18,00,570,Subsidiary 2
18,00,120,Subsidiary 3

The query that creates the end result is deliberately simple: it shows how these two files can be joined and how a surrogate key can be created. If we look at the assertion data:

subsidiarySK,subsidiaryNumber,subsidiaryName
-1318448993368083429,120,Subsidiary 3
-1488953326447475322,570,Subsidiary 2
-5381099075221240468,999,Subsidiary 1

We can see that we expect three records that represent the join and the creation of a surrogate key.
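
The notebook’s query itself is not reproduced in this post, but conceptually it joins the two seed tables and hashes the natural key into a surrogate key. A hypothetical Spark SQL version is sketched below; the column mappings, the filters, and the use of xxhash64 for the surrogate key are assumptions inferred from the sample data, not the actual notebook code.

# Hypothetical sketch of the transformation; not the actual notebook query.
# During test runs the notebook points at unit_test_db instead of the Silver zone.
dimSubsidiaryDF = spark.sql("""
    select
        xxhash64(bu.Subsidiary) as subsidiarySK,
        bu.Subsidiary           as subsidiaryNumber,
        udc.Description         as subsidiaryName
    from unit_test_db.jde_vwBusinessUnitMaster bu
    join unit_test_db.jde_vwUserDefinedCodeValues udc
        on udc.userDefinedCode = bu.Subsidiary
    where udc.userDefinedCodes = '18'
      and udc.productCode = '00'
""")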

“But how does this all fit together?!” you might be asking. Let’s take a look at the pytest test that uses each utility we created and walk through each section in depth.

Initialization and Spark setup

import pytest
import json
from pyspark.sql import SparkSession
from pyspark.dbutils import DBUtils
from utils import sparkutils, databricksutils, compareutils, configutils

# databricks-connect routes this Spark session to the remote Databricks cluster
spark = SparkSession.builder.enableHiveSupport().getOrCreate()
spark.sparkContext.setLogLevel("OFF")
dbutils = DBUtils(spark)

Since we’ve set up databricks-connect, we must initialize a Spark session that runs on our Databricks cluster. This is done in the first part of the unit test.

Creating tables linked to seed files

def test_simple_transform_differences():
    database = "unit_test_db"
    tables = [
        {"Name": "jde_vwBusinessUnitMaster",
         "Path": "/mnt/erdw/test/SeedData/Test/seed_BusinessUnitMaster.csv"},
        {"Name": "jde_vwUserDefinedCodeValues",
         "Path": "/mnt/erdw/test/SeedData/Test/seed_UserDefinedCodes.csv"}
    ]
    
    seedTables = sparkutils.util(database, tables)
    seedTables.create()

Test functions must be prefixed with “test” in order to be collected by pytest. Here we use a database name and a list of dictionaries of table names and seed file paths to initialize a sparkutils.util class. This class creates the tables in Databricks using the database and table names provided, and these must match what the notebook queries. In this example we use “unit_test_db”.
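
For context, here is a minimal sketch of what sparkutils.util could do, assuming the seed files are header-first CSVs; the real utility in the linked project may handle schemas and write options differently.

# Hypothetical sketch of sparkutils.py; the real utility may differ.
from pyspark.sql import SparkSession


class util:
    """Creates a database and registers one table per seed file."""

    def __init__(self, database, tables):
        self.spark = SparkSession.builder.enableHiveSupport().getOrCreate()
        self.database = database
        self.tables = tables

    def create(self):
        # Make sure the test database exists, then load each seed CSV as a table
        self.spark.sql(f"CREATE DATABASE IF NOT EXISTS {self.database}")
        for table in self.tables:
            df = (self.spark.read.format("csv")
                  .option("header", "true")
                  .load(table["Path"]))
            df.write.mode("overwrite").saveAsTable(f"{self.database}.{table['Name']}")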

Using Configs to run notebook

    configs = configutils.util()
    #Run the notebook as a unit test
    notebookInstance = databricksutils.instance(configs.getHost(), configs.getToken(), configs.getClusterId())
    args = {"scopeName" : "sb-databricks-keyvault"
        , "configuration" : "Test"
    }
    notebookInstance.runNotebook("/ERDW/TranLoads/TRAN_Simple_Test", args)
    if notebookInstance.getRunResultState() == "SUCCESS":
        #get returns from notebook run
        print(notebookInstance.getRunOutput())
        testNotebookJson = json.loads(notebookInstance.getRunOutput())
        testFilePath = testNotebookJson["FilePath"]
        testStatus = testNotebookJson["ExecutionStatus"]

In this section the databricksutils.instance class is used to create a notebook run on an all-purpose cluster via the Databricks API. The same configuration used to connect to your cluster from your IDE, kept in configs/.databricks-connect, is reused here through the configutils.util class to make the API calls. The utility submits the run, checks its state periodically, and returns both the result state and, if the run succeeds, the notebook’s output.
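
To make that concrete, here is a condensed sketch of how such a utility could wrap the Databricks Jobs REST API (runs/submit, runs/get, and runs/get-output). The class and method names mirror how the test uses them; the payload shape, polling interval, and error handling are assumptions rather than the project’s actual implementation.

# Hypothetical sketch of databricksutils.py built on the Jobs REST API;
# the real utility may differ in structure and error handling.
import time
import requests


class instance:
    def __init__(self, host, token, clusterId):
        self.host = host.rstrip("/")
        self.headers = {"Authorization": f"Bearer {token}"}
        self.clusterId = clusterId
        self.runId = None
        self.resultState = None

    def runNotebook(self, notebookPath, args):
        # Submit a one-time run of the notebook on the all-purpose cluster
        payload = {
            "run_name": "unit test run",
            "existing_cluster_id": self.clusterId,
            "notebook_task": {"notebook_path": notebookPath, "base_parameters": args},
        }
        response = requests.post(f"{self.host}/api/2.0/jobs/runs/submit",
                                 headers=self.headers, json=payload)
        self.runId = response.json()["run_id"]
        # Poll until the run reaches a terminal state
        while True:
            state = requests.get(f"{self.host}/api/2.0/jobs/runs/get",
                                 headers=self.headers,
                                 params={"run_id": self.runId}).json()["state"]
            if state["life_cycle_state"] in ("TERMINATED", "SKIPPED", "INTERNAL_ERROR"):
                self.resultState = state.get("result_state", "FAILED")
                break
            time.sleep(15)

    def getRunResultState(self):
        return self.resultState

    def getRunOutput(self):
        # Returns whatever the notebook passed to dbutils.notebook.exit()
        output = requests.get(f"{self.host}/api/2.0/jobs/runs/get-output",
                              headers=self.headers,
                              params={"run_id": self.runId}).json()
        return output.get("notebook_output", {}).get("result")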

Check output against expected


        if (testStatus == "Pass"):
            # Arrange - limit returned data to what is in assertions
            testFileDF = spark.read.format("parquet").load(testFilePath)
            testFileDF.createOrReplaceTempView("returned_dimSubsidiaryJDE")
            returnedDF = spark.sql("""select 
              subsidiarySK
            , subsidiaryNumber
            , subsidiaryName
            from returned_dimSubsidiaryJDE """)
            # Assert - load expected results
            expectedPath = "/mnt/erdw/test/ExpectedData/Test/expected_dimSubsidiary.csv"
            expectedDF = spark.read.format("csv").option("header", "true").option("nullValue", None).option(
                "emptyValue", None).load(expectedPath)
            expectedCount = expectedDF.count()
            # Act - compare returned and expected DataFrames
            diffDF = compareutils.diffdataframe(returnedDF, expectedDF, "subsidiarySK")
            assert diffDF.leftNotInRight == 0, f"Failed left not in right. Expected 0 and got {diffDF.leftNotInRight}."
            assert diffDF.rightNotInLeft == 0, f"Failed right not in left. Expected 0 and got {diffDF.rightNotInLeft}."
            assert diffDF.changes == 0, f"Failed changes. Expected 0 and got {diffDF.changes}."
            assert diffDF.sameInBoth == expectedCount, f"Failed same in both. Expected {expectedCount} and got {diffDF.sameInBoth}."

The last section of the unit test takes the output file path from the notebook run, reads the file into a DataFrame, creates a temp view, and builds another DataFrame from a SQL query over that view. This may seem redundant, but when output files are extremely large and we only want to test certain fields, the SQL gives us a place to limit which columns and rows are checked.

Next, the expected results are loaded to a DataFrame and then compared to the returned results via the compareutils.diffdataframe class. This class returns a number of counts depending on the differences between the two DataFrames.

Assertions are then made against the expected counts from each difference check. We expect zero rows for leftNotInRight, rightNotInLeft, and changes, and the sameInBoth count to match the row count of the expected DataFrame.

Running Unit Tests

If you’ve set up your IDE to run on your cluster correctly, running the unit tests locally is as easy as calling the following:
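
A minimal example, assuming pytest is the test runner (which the “test” naming convention and the pytest import suggest) and the tests live in the unittests folder shown above:

python -m pytest unittests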

In the next section we will see how we can incorporate this into a CI/CD pipeline to run these tests automatically every time we check in a change to Databricks.

Complete File

The complete file can be found here.

https://github.com/CharlesRinaldini/Databricks-UnitTestNotebooks/blob/master/unittests/test_simple.py

Continue to Part 4
