Fast-ish Data Loading to SQL from Databricks

I saw a Reddit thread last week about someone's issue with serving Delta table data at very low latency, which is typical of OLTP applications. Sometimes data stored in your lake needs to be copied to an RDBMS to provide the speed that mobile or desktop applications need. I have faced this problem many times over the years; before things like lakehouse applications, there was no easy way to get the speed that was needed. Often we found ourselves copying small sets of data and merging them into target tables, but other times we needed to copy huge sets of data.

Some things that tools like Azure Data Factory are good at, Databricks is not; case in point, copying very large datasets to SQL Server. At one point we were offered a "fast loading" JDBC driver from Microsoft, but it hasn't been updated in years, and now we are stuck with the out-of-the-box driver, which loads row by row and is painfully slow. Because of this, I was recently led down the path of trying to remove our dependency on that driver without sacrificing the speed it offered.

To accomplish this, we had two problems to solve: first, understanding how the fast-loading driver works (or at least finding alternatives to it), and second, removing the need to load libraries at cluster launch time. I hypothesized that if SQL Server could read the data from the lake in a format it understood, and we could issue the load commands to the SQL engine ourselves, we could get close to the same performance.

Enter PolyBase and a class of our own to handle all of this. The idea was simple: write Parquet files to an area of our lake that is exposed to SQL through an external data source, then use PolyBase to ingest those files either directly into the target table (truncate and reload) or into a staging table that is then merged into the target. We would use traditional ETL tricks like bulk loads and disabling indexes to speed up the loading as much as possible.

At the core of this process is the method insertIntoAzSqlFromBlob. Its arguments are the incoming DataFrame (i_df), the JDBC connection information (i_jdbcUrl and i_connectionProperties), the name of the table that is ultimately written to (i_targetTable), whether or not a single file should be created (i_singleFile), a setting that controls how verbose the output is (i_config, where "Debug" prints every generated SQL statement), and whether to append to or overwrite the target table (i_writeMode). It also relies on the class's tempURL attribute, the storage location under which the Parquet files are staged.

First, a select list is built from the schema of the target table. A "select top 0" query returns an empty DataFrame with the target's columns, and those columns are looped over to build a comma-separated select list. This drops any extra columns in the incoming DataFrame that do not map to the target.

colSql = f"select top 0 * from {i_targetTable}"
targetCols = executeSQL(colSql, i_jdbcUrl, i_connectionProperties).columns
targetSelect = "select "
for targetCol in targetCols:
  targetSelect += f"{targetCol}, "
targetSelect = targetSelect[:-2]
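A slightly more defensive variant of the same idea is sketched below. buildSelectList is a hypothetical helper, not part of the original class: it joins the column names instead of trimming a trailing comma, and wraps each name in square brackets (doubling any closing bracket, which is how T-SQL escapes delimited identifiers) so that columns with spaces or reserved words still produce valid SQL.

```python
def buildSelectList(targetCols):
    """Build a T-SQL select list from the target table's column names.

    Hypothetical helper: each name is bracket-quoted and any ']' inside a
    name is doubled so the delimited identifier stays valid.
    """
    if not targetCols:
        raise ValueError("Target table returned no columns")
    quoted = ["[" + col.replace("]", "]]") + "]" for col in targetCols]
    return "select " + ", ".join(quoted)


# In the class, targetCols would come from executeSQL(...).columns
print(buildSelectList(["Id", "Order Date", "Total"]))
# → select [Id], [Order Date], [Total]
```

If you adopt bracket quoting here, the column names in the OPENROWSET WITH clause need the same treatment, otherwise a name like Order Date will still fail there.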

Next, we construct the WITH clause of the OPENROWSET command to avoid schema inference problems. We do this with getCommonJDBCType, which maps Spark data types to SQL Server types and is adapted from the Scala version in Spark's JDBC code. It was actually ported to PySpark using the Databricks Assistant!

withSchema = ""
for dfCol in i_df.schema:
  withSchema = withSchema + f"\n\t{dfCol.name} {getCommonJDBCType(dfCol.dataType)},"
withSchema = withSchema[:-1]
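One caveat: getCommonJDBCType returns None for Spark types it does not map (for example arrays, maps, structs and binary), and the loop above would then emit a column definition such as "Payload None", which SQL Server only rejects when the statement runs. A minimal fail-fast sketch follows; buildWithSchema is hypothetical and takes (column name, SQL type) pairs, so in the class you would feed it [(c.name, getCommonJDBCType(c.dataType)) for c in i_df.schema].

```python
def buildWithSchema(columns):
    """Build the body of an OPENROWSET ... WITH ( ... ) clause.

    columns: iterable of (columnName, sqlType) pairs. sqlType is None when
    no mapping exists, in which case we raise before SQL Server sees it.
    """
    lines = []
    unmapped = []
    for name, sqlType in columns:
        if sqlType is None:
            unmapped.append(name)
        else:
            lines.append(f"\n\t{name} {sqlType}")
    if unmapped:
        raise ValueError(f"No SQL Server type mapping for column(s): {', '.join(unmapped)}")
    # Join with commas instead of appending one to every line and trimming the last
    return ",".join(lines)


withSchema = buildWithSchema([("Id", "INTEGER"), ("Total", "DECIMAL(18,2)")])
print(withSchema)
```

The output has the same shape as the original loop's, so it can be dropped into the bulk insert statement unchanged.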

The next block of code is self-explanatory: the directory the DataFrame is written to, {tempURL}_temp_out/{i_targetTable}/, is removed if it exists.

tempPath = f"{self.tempURL}_temp_out/{i_targetTable}/"

try:
    dbutils.fs.rm(tempPath, True)
    printWithTime(f"Removed: {tempPath}")
except:
    printWithTime(f"Did not find {tempPath}... Will create")

This method is called from a notebook that records metadata about each load, so the number of records written matters. The next lines cache the DataFrame and get its row count for that purpose. Caching also means the Parquet write that follows reads the cached data instead of recomputing the DataFrame from its source.

printWithTime("Getting counts...")   
o_cnt = i_df.cache().count()

Next, either a single file or multiple files are written. Writing multiple files avoids overwhelming SQL Server with one very large Parquet file, and the idea is that smaller files fit into memory better and can be loaded faster.

if i_singleFile:
  printWithTime("Writing blob to temp...")  
  i_df.coalesce(1).write.mode("overwrite").option("overwriteSchema", True).parquet(f"{tempPath}")      
else:
  printWithTime("Writing blobs to temp...")  
  i_df.write.mode("overwrite").option("overwriteSchema", True).parquet(f"{tempPath}")   
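Between a single file and however many partitions the DataFrame happens to have, you can also aim for a particular file count. The helper below is a hypothetical sketch: it estimates a number of files from the row count already computed by cache().count() and an assumed rows-per-file target that you would tune for your own data. Spark's maxRecordsPerFile writer option is an alternative that caps the rows in each file without a shuffle.

```python
import math


def estimateFileCount(rowCount, rowsPerFile=1_000_000):
    """Estimate how many parquet files to write for a given row count.

    rowsPerFile is an assumed tuning knob, not a value from the original
    code. At least one file is always returned.
    """
    if rowsPerFile <= 0:
        raise ValueError("rowsPerFile must be positive")
    return max(1, math.ceil(rowCount / rowsPerFile))


# In the class this could drive the multi-file branch, e.g.:
#   i_df.repartition(estimateFileCount(o_cnt)).write.mode("overwrite").parquet(tempPath)
print(estimateFileCount(2_500_000))  # → 3
print(estimateFileCount(0))          # → 1
```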

Next, we need the number of Parquet files that were created so we can report progress. We get it by listing the created directory and counting the files that end in .parquet.

files = dbutils.fs.ls(tempPath)
parquetCount = 0
for tempFile in files:
  if tempFile.name.endswith(".parquet"):  
    parquetCount += 1

basePath = f"{self.tempURL}_temp_out/"
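The counting loop and basePath work together: later, basePath is stripped from each file's full path, so the path passed to OPENROWSET is relative to the external data source, which therefore has to point at the _temp_out/ folder. The sketch below reproduces that arithmetic with a stand-in for the FileInfo objects returned by dbutils.fs.ls; the storage URL and file names are made up for illustration.

```python
from collections import namedtuple

# Stand-in for the FileInfo entries returned by dbutils.fs.ls (only path and name)
FileInfo = namedtuple("FileInfo", ["path", "name"])

tempURL = "abfss://lake@myaccount.dfs.core.windows.net/"  # hypothetical storage URL
basePath = f"{tempURL}_temp_out/"
tempPath = f"{basePath}dbo.Orders/"

files = [
    FileInfo(f"{tempPath}_SUCCESS", "_SUCCESS"),
    FileInfo(f"{tempPath}part-00000.snappy.parquet", "part-00000.snappy.parquet"),
    FileInfo(f"{tempPath}part-00001.snappy.parquet", "part-00001.snappy.parquet"),
]

# Only .parquet files are loaded; marker files such as _SUCCESS are skipped
parquetFiles = [f for f in files if f.name.endswith(".parquet")]
parquetCount = len(parquetFiles)

# Path handed to OPENROWSET, relative to the external data source location
relativePaths = [f.path.replace(basePath, "") for f in parquetFiles]

print(parquetCount)      # → 2
print(relativePaths[0])  # → dbo.Orders/part-00000.snappy.parquet
```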

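The next couple of lines are a hack that uses the JVM's JDBC DriverManager, reached through Spark's Py4J gateway, to execute arbitrary SQL statements (not just stored procedures) on SQL Server. This is our alternative to pyodbc, which effectively does the same thing but would mean installing extra libraries at cluster launch time.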

driver_manager = spark.sparkContext._gateway.jvm.java.sql.DriverManager
spConn = driver_manager.getConnection(i_jdbcUrl, i_connectionProperties["user"], i_connectionProperties["password"])
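This connection is opened once and reused for every statement, so an exception partway through the load would skip any close call at the end and leave it open. One way to guarantee cleanup is the standard library's contextlib.closing, which calls close() when the block exits; a Py4J proxy for a java.sql.Connection exposes close(), so it should work unchanged. FakeConnection and runLoad below are hypothetical and exist only so the sketch runs outside Databricks.

```python
from contextlib import closing


class FakeConnection:
    """Stand-in for the Py4J-wrapped java.sql.Connection (illustration only)."""

    def __init__(self):
        self.closed = False

    def close(self):
        self.closed = True


def runLoad(conn, failMidway=False):
    # closing() calls conn.close() when the block exits, even on an exception
    with closing(conn):
        if failMidway:
            raise RuntimeError("bulk insert failed")
        return "loaded"


conn = FakeConnection()
print(runLoad(conn), conn.closed)  # → loaded True

conn2 = FakeConnection()
try:
    runLoad(conn2, failMidway=True)
except RuntimeError:
    pass
print(conn2.closed)  # → True
```

In the method itself, the truncate, bulk insert and rebuild steps would sit inside with closing(driver_manager.getConnection(...)) as spConn:.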

To greatly speed up writing to the target, its nonclustered indexes are disabled and, unless the write mode is append, the table is truncated first.

truncAndDropSql = f"""
set nocount on;

if lower('{i_writeMode}') != 'append' 
  truncate table {i_targetTable}

declare @id int, @sqlcmd varchar(400)
declare @sqls table (
  id int identity(1, 1),
  sqlcmd varchar(400)
)

insert into @sqls (sqlcmd)
select 'ALTER INDEX [' + i.name + '] ON [' + s.name + '].[' + t.name + '] DISABLE' 
from sys.indexes i
join sys.tables t 
on i.object_id = t.object_id
join sys.schemas s
on t.schema_id = s.schema_id
where i.type_desc = 'NONCLUSTERED'
and i.name is not null
and i.is_disabled = 0
and concat(s.name, '.', t.name) = '{i_targetTable}'

select top 1 @id = id, @sqlcmd = sqlcmd from @sqls

while (@@rowcount > 0)
begin
  exec(@sqlcmd)
  delete from @sqls where id = @id
  select top 1 @id = id, @sqlcmd = sqlcmd from @sqls
end
"""
if isDebug:
  printWithTime(truncAndDropSql)
truncDropStatement = spConn.prepareCall(truncAndDropSql)
truncDropStatement.execute()     
truncDropStatement.close()      
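The index-disable query matches on concat(s.name, '.', t.name) = '{i_targetTable}', so the table name has to arrive as schema.table with no brackets. A bare Orders or a bracketed [dbo].[Orders] would still be accepted by TRUNCATE TABLE, but the query would find no indexes and silently leave them enabled. A small hypothetical guard like the one below, called at the top of the method, turns that into an immediate error; it also rejects single quotes, which would break the string literal in the generated SQL.

```python
def checkTargetTableName(targetTable):
    """Validate the 'schema.table' form that the index-disable query expects.

    Hypothetical helper: returns (schema, table) or raises ValueError.
    """
    if "[" in targetTable or "]" in targetTable or "'" in targetTable:
        raise ValueError(f"Pass the name without brackets or quotes: {targetTable!r}")
    parts = targetTable.split(".")
    if len(parts) != 2 or not all(parts):
        raise ValueError(f"Expected 'schema.table', got {targetTable!r}")
    return parts[0], parts[1]


print(checkTargetTableName("dbo.Orders"))  # → ('dbo', 'Orders')
```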

Finally, with the target table ready for bulk loading, a bulk insert with a table lock (TABLOCK) is issued through PolyBase. Note that to use PolyBase you need to enable it on your SQL Server; there are different ways to do this, so research which is appropriate for your SQL SKU. Once it is enabled, you will need to create a database master key, a database scoped credential, and an external data source. You can use SAS keys or a Managed Identity. If you go down the Managed Identity path, you will need to grant your SQL Server's identity at least the Storage Blob Data Reader role on your ADLS account, as well as the appropriate ACLs on the storage container.
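The three objects mentioned above are created with T-SQL. As a rough sketch, the hypothetical function below returns those statements for the Managed Identity route so they could be run once, for example through executeSQLCommandNoReturn. All names, the password and the location are placeholders; the LOCATION prefix differs between SQL Server 2022, Azure SQL Managed Instance and Azure SQL Database, so check the documentation for your SKU. For the relative file paths used by the bulk insert loop to resolve, the location should point at the _temp_out/ folder.

```python
def buildPolybaseSetupSql(credentialName, dataSourceName, location, masterKeyPassword):
    """Return one-time T-SQL setup statements for a managed-identity data source.

    Placeholder-driven sketch: verify the LOCATION format for your SQL SKU.
    With a SAS key instead, the credential would use
    IDENTITY = 'SHARED ACCESS SIGNATURE' plus a SECRET.
    """
    return [
        f"CREATE MASTER KEY ENCRYPTION BY PASSWORD = '{masterKeyPassword}';",
        f"CREATE DATABASE SCOPED CREDENTIAL [{credentialName}] "
        "WITH IDENTITY = 'Managed Identity';",
        f"CREATE EXTERNAL DATA SOURCE [{dataSourceName}] "
        f"WITH (LOCATION = '{location}', CREDENTIAL = [{credentialName}]);",
    ]


for stmt in buildPolybaseSetupSql(
    "LakeCred",                                                      # hypothetical
    "LakeTempOut",                                                   # hypothetical
    "abs://lake@myaccount.blob.core.windows.net/_temp_out/",         # hypothetical
    "<strong password>",
):
    print(stmt)
```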

printWithTime("Executing bulk inserts...")
readFiles = 0
for tempFile in files:
  tempFileName = tempFile.path
  if tempFileName.endswith(".parquet"):     
    readFiles += 1
    printWithTime(f"Processing file {readFiles} of {parquetCount}...")   
    polybasePath = tempFileName.replace(basePath, "")
    bulkInsertSql = f"""    
    set nocount on;
    insert into {i_targetTable} with(tablock)
    {targetSelect} 
    from openrowset(
      bulk '{polybasePath}',
      data_source = '<yourDataSourceName>',
      format = 'parquet'
    ) with ( {withSchema} \n) as tempBlob"""        

    if isDebug:
      printWithTime(bulkInsertSql)
      
    procStatement = spConn.prepareCall(bulkInsertSql)
    procStatement.execute()     
    procStatement.close() 
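To inspect the generated statement without connecting to SQL Server, the string building in the loop can be pulled into a pure function. buildBulkInsertSql below is a hypothetical refactoring of the loop body, not part of the original class; besides producing the same statement shape, it doubles single quotes in the file path so an unusual file name cannot break the string literal.

```python
def buildBulkInsertSql(targetTable, targetSelect, relativePath, dataSourceName, withSchema):
    """Build one INSERT ... SELECT ... FROM OPENROWSET(BULK ...) statement for a parquet file."""
    safePath = relativePath.replace("'", "''")  # escape the path for a T-SQL string literal
    return (
        "set nocount on;\n"
        f"insert into {targetTable} with(tablock)\n"
        f"{targetSelect}\n"
        "from openrowset(\n"
        f"  bulk '{safePath}',\n"
        f"  data_source = '{dataSourceName}',\n"
        "  format = 'parquet'\n"
        f") with ( {withSchema} \n) as tempBlob"
    )


sql = buildBulkInsertSql(
    "dbo.Orders",
    "select Id, Total",
    "dbo.Orders/part-00000.snappy.parquet",
    "LakeTempOut",  # hypothetical external data source name
    "\n\tId INTEGER,\n\tTotal DECIMAL(18,2)",
)
print(sql)
```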

Lastly, once all Parquet files have been loaded, the indexes on the target table are rebuilt (ALTER INDEX ALL ... REBUILD also re-enables the ones that were disabled), the connection is closed, and the DataFrame's row count is returned.

enableSql = f"ALTER INDEX ALL ON {i_targetTable} REBUILD;"

printWithTime("Enabling indexes...")  
if isDebug:
  printWithTime(enableSql)
enableStatement = spConn.prepareCall(enableSql)
enableStatement.execute()     
enableStatement.close()
spConn.close()

return o_cnt

Typically, this class is used from a notebook that builds a DataFrame and passes it to insertIntoAzSqlFromBlob. We normally use this process for full loads of tables, but when the target is a staging table that is then merged into an existing table, you can use something like the following method to execute the MERGE statement:

def executeSQLCommandNoReturn(self, i_sqlCmd, i_jdbcUrl, i_SQLUsername, i_SQLPassword):
    # Execute a SQL command that returns no result set; the statement and connection are closed even if it fails
    driver_manager = spark.sparkContext._gateway.jvm.java.sql.DriverManager
    spConn = driver_manager.getConnection(i_jdbcUrl, i_SQLUsername, i_SQLPassword)
    try:
        procStatement = spConn.prepareCall(i_sqlCmd)
        try:
            procStatement.execute()
        finally:
            procStatement.close()
    finally:
        spConn.close()
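The post does not include the MERGE statement itself. As an illustration only, here is a hypothetical generator for a common staging-to-target upsert; the table names, the key columns and the choice to update every non-key column are all assumptions. The resulting string would be passed as i_sqlCmd to executeSQLCommandNoReturn.

```python
def buildMergeSql(targetTable, stagingTable, keyCols, allCols):
    """Build a T-SQL MERGE that upserts rows from a staging table into a target.

    Assumes staging and target share column names; every non-key column is updated.
    """
    def q(col):
        # Bracket-quote an identifier, doubling any ']' inside it
        return "[" + col.replace("]", "]]") + "]"

    onClause = " and ".join(f"tgt.{q(c)} = src.{q(c)}" for c in keyCols)
    updateCols = [c for c in allCols if c not in keyCols]
    insertCols = ", ".join(q(c) for c in allCols)
    insertVals = ", ".join(f"src.{q(c)}" for c in allCols)

    sql = f"merge into {targetTable} as tgt\nusing {stagingTable} as src\non {onClause}\n"
    if updateCols:
        setClause = ", ".join(f"tgt.{q(c)} = src.{q(c)}" for c in updateCols)
        sql += f"when matched then update set {setClause}\n"
    # MERGE statements must be terminated with a semicolon
    sql += f"when not matched by target then insert ({insertCols}) values ({insertVals});"
    return sql


print(buildMergeSql("dbo.Orders", "stage.Orders", ["Id"], ["Id", "Total"]))
```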

The one thing I could not figure out is whether it is possible to multithread the bulk insert SQL commands. The table lock would of course need to be removed, but hypothetically, if the table were partitioned the same way as the incoming DataFrame, this could be really, really fast. If anyone figures that out, I'd love to see how it is done.
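For anyone who wants to experiment, here is an untested sketch of the shape such a change might take, using the standard library's ThreadPoolExecutor. It assumes a hypothetical runStatement(sql) callable that opens its own connection for each call (sharing one SQL Server connection would likely serialize the work anyway) and per-file statements built without the TABLOCK hint. Whether this is actually faster than the serial loop is exactly the open question above.

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def runInParallel(statements, runStatement, maxWorkers=4):
    """Submit each per-file statement to a thread pool and wait for all of them.

    runStatement is a caller-supplied (hypothetical) function that executes one
    SQL string on its own connection. Returns the number of statements run;
    an exception raised by any statement is re-raised here.
    """
    completed = 0
    with ThreadPoolExecutor(max_workers=maxWorkers) as pool:
        futures = [pool.submit(runStatement, sql) for sql in statements]
        for future in as_completed(futures):
            future.result()  # re-raises any exception from the worker thread
            completed += 1
    return completed


# Illustration with a stand-in executor that just records what it was given
executed = []
count = runInParallel([f"insert ... file {i}" for i in range(5)], executed.append)
print(count, len(executed))  # → 5 5
```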

The entire snippet of code can be found below:

from databricks.sdk.runtime import *
from pyspark.sql import SparkSession as ss

from pyspark.sql.types import *
from pyspark.sql.functions import *
from datetime import datetime

def printWithTime(i_str):
  # Print a string with the current timestamp
  now = datetime.now()
  modifiedDatetime = now.strftime("%Y-%m-%d %H:%M:%S")
  print(f"{modifiedDatetime}: {i_str}")

def executeSQL(i_sqlSelect, i_jdbcUrl, i_jdbcConnectionProperties):
  # Execute a SQL query and return the result as a DataFrame
  i_sqlSelect = "(" + i_sqlSelect + ") AS SQLTable"

  o_df = spark.read.jdbc(url = i_jdbcUrl, table = i_sqlSelect, properties = i_jdbcConnectionProperties)
  return o_df

def getCommonJDBCType(dt):
  if isinstance(dt, IntegerType):
    return ("INTEGER")
  elif isinstance(dt, LongType):
    return ("BIGINT")
  elif isinstance(dt, DoubleType):
    return ("FLOAT")
  elif isinstance(dt, FloatType):
    return ("FLOAT")
  elif isinstance(dt, ShortType):
    return ("SMALLINT")
  elif isinstance(dt, ByteType):
    return ("TINYINT")
  elif isinstance(dt, BooleanType):
    return ("BIT")
  elif isinstance(dt, StringType):
    return (f"NVARCHAR(MAX)")
  elif isinstance(dt, CharType):
    return (f"CHAR({dt.length})")
  elif isinstance(dt, VarcharType):
    return (f"VARCHAR({dt.length})")
  elif isinstance(dt, TimestampType):
    return ("DATETIME")
  elif isinstance(dt, TimestampNTZType):
    return ("DATETIME")
  elif isinstance(dt, DateType):
    return ("DATE")
  elif isinstance(dt, DecimalType):
    return (f"DECIMAL({dt.precision},{dt.scale})")
  else:
    return None

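class sql_helper:
  def __init__(self, i_tempURL):
    # Base storage URL; parquet files are staged under f"{i_tempURL}_temp_out/<targetTable>/"
    self.tempURL = i_tempURL
    try:
      # Touch sparkContext now so an unsupported cluster fails here instead of mid-load
      spark.sparkContext
      printWithTime("Created helper class")
    except BaseException as e:
      if "Attribute `sparkContext` is not supported" in str(e):
        raise Exception("SparkContext is not available. Please run this notebook on a single user UC enabled cluster.")
      else:
        raise Exception(str(e))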
      
  def insertIntoAzSqlFromBlob(self, i_df, i_jdbcUrl, i_connectionProperties, i_targetTable, i_singleFile, i_config, i_writeMode):    
    # Bulk inserts a DataFrame to an Azure SQL table from a blob or multiple blobs based on the i_singleFile parameter
    if i_config == "Debug":
        isDebug = True
    else:
        isDebug = False

    # Build select list based off of the schema of the target table
    colSql = f"select top 0 * from {i_targetTable}"
    targetCols = executeSQL(colSql, i_jdbcUrl, i_connectionProperties).columns
    targetSelect = "select "
    for targetCol in targetCols:
      targetSelect += f"{targetCol}, "
    targetSelect = targetSelect[:-2]

    # Define the schema of the bulk insert based off of common JDBC types xref to spark types
    withSchema = ""
    for dfCol in i_df.schema:
      withSchema = withSchema + f"\n\t{dfCol.name} {getCommonJDBCType(dfCol.dataType)},"
    withSchema = withSchema[:-1]
    
    tempPath = f"{self.tempURL}_temp_out/{i_targetTable}/"

    # Remove the directory for the persisted parquet files if it exists
    try:
        dbutils.fs.rm(tempPath, True)
        printWithTime(f"Removed: {tempPath}")
    except:
        printWithTime(f"Did not find {tempPath}... Will create")

    # Count the incoming DataFrame and cache it
    printWithTime("Getting counts...")   
    o_cnt = i_df.cache().count()

    # Write a single large blob or multiple parquet files
    if i_singleFile:
      printWithTime("Writing blob to temp...")  
      i_df.coalesce(1).write.mode("overwrite").option("overwriteSchema", True).parquet(f"{tempPath}")      
    else:
      printWithTime("Writing blobs to temp...")  
      i_df.write.mode("overwrite").option("overwriteSchema", True).parquet(f"{tempPath}")      

    # Count the number of parquet files in the target folder to keep track of progress later
    files = dbutils.fs.ls(tempPath)
    parquetCount = 0
    for tempFile in files:
      if tempFile.name.endswith(".parquet"):  
        parquetCount += 1

    #all storage folders are suffixed with _temp_out 
    basePath = f"{self.tempURL}_temp_out/"
    
    # Use the driver manager to get a connection to the target database
    driver_manager = spark.sparkContext._gateway.jvm.java.sql.DriverManager
    spConn = driver_manager.getConnection(i_jdbcUrl, i_connectionProperties["user"], i_connectionProperties["password"])

    if i_writeMode.lower() != 'append':
      printWithTime("Truncating target table and disabling indexes...")  
    else:
      printWithTime("Disabling indexes...")

    # Run the following SQL that will truncate the target table if required and disable any non-clustered indexes if they exist on the target table
    truncAndDropSql = f"""
    set nocount on;

    if lower('{i_writeMode}') != 'append' 
      truncate table {i_targetTable}

    declare @id int, @sqlcmd varchar(400)
    declare @sqls table (
      id int identity(1, 1),
      sqlcmd varchar(400)
    )

    insert into @sqls (sqlcmd)
    select 'ALTER INDEX [' + i.name + '] ON [' + s.name + '].[' + t.name + '] DISABLE' 
    from sys.indexes i
    join sys.tables t 
    on i.object_id = t.object_id
    join sys.schemas s
    on t.schema_id = s.schema_id
    where i.type_desc = 'NONCLUSTERED'
    and i.name is not null
    and i.is_disabled = 0
    and concat(s.name, '.', t.name) = '{i_targetTable}'

    select top 1 @id = id, @sqlcmd = sqlcmd from @sqls

    while (@@rowcount > 0)
    begin
      exec(@sqlcmd)
      delete from @sqls where id = @id
      select top 1 @id = id, @sqlcmd = sqlcmd from @sqls
    end
    """
    if isDebug:
      printWithTime(truncAndDropSql)
    truncDropStatement = spConn.prepareCall(truncAndDropSql)
    truncDropStatement.execute()     
    truncDropStatement.close()      
    
    # Loop the files in the persisted folder of parquet files and execute a bulk insert for each one
    printWithTime("Executing bulk inserts...")
    readFiles = 0
    for tempFile in files:
      tempFileName = tempFile.path
      if tempFileName.endswith(".parquet"):     
        readFiles += 1
        printWithTime(f"Processing file {readFiles} of {parquetCount}...")   
        polybasePath = tempFileName.replace(basePath, "")
        bulkInsertSql = f"""    
        set nocount on;
        insert into {i_targetTable} with(tablock)
        {targetSelect} 
        from openrowset(
          bulk '{polybasePath}',
          data_source = '<yourDataSourceName>',
          format = 'parquet'
        ) with ( {withSchema} \n) as tempBlob"""        

        if isDebug:
          printWithTime(bulkInsertSql)
          
        procStatement = spConn.prepareCall(bulkInsertSql)
        procStatement.execute()     
        procStatement.close()       
      
    # Once complete then rebuild the indexes on the target table
    enableSql = f"ALTER INDEX ALL ON {i_targetTable} REBUILD;"
    
    printWithTime("Enabling indexes...")  
    if isDebug:
      printWithTime(enableSql)
    enableStatement = spConn.prepareCall(enableSql)
    enableStatement.execute()     
    enableStatement.close()

    # Release the JDBC connection and the cached DataFrame
    spConn.close()
    i_df.unpersist()

    return o_cnt
