Source code for core_db.engines.snowflake_

# -*- coding: utf-8 -*-

"""
Snowflake Data Warehouse Client Module
========================================

This module provides the SnowflakeClient class for connecting to and interacting
with Snowflake Data Warehouse using the snowflake-connector-python library.
"""

import json
import re
from datetime import date
from datetime import datetime
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import overload

import snowflake.connector

from core_db.interfaces.sql_based import ISqlDatabaseClient


class SnowflakeClient(ISqlDatabaseClient):
    """
    Client for Snowflake Data Warehouse connection.

    This client builds SQL statements with validated identifiers and escaped
    literal values to prevent SQL injection attacks. It supports standard CRUD
    operations, batch inserts, and upserts (MERGE).

    ===================================================
    Usage Examples
    ===================================================

    Basic Connection and Query:
    ----------------------------

    .. code-block:: python

        from core_db.engines.snowflake_ import SnowflakeClient

        config = {
            "user": "username",
            "password": "password",
            "account": "account_name",
            "warehouse": "warehouse_name",
            "database": "database_name",
            "schema": "schema_name"
        }

        with SnowflakeClient(**config) as client:
            client.execute("SELECT CURRENT_VERSION();")
            print(client.fetch_one())
    ..

    Security Features:
    ------------------
    - Validates all SQL identifiers (table/column names) against injection patterns
    - Escapes string values to prevent SQL injection
    - Supports parameterized queries where applicable
    - Validates fully qualified names (schema.table)
    """

    # Python type -> Snowflake column type.
    TYPE_MAPPER = {
        int: "INTEGER",
        float: "DOUBLE",
        str: "VARCHAR",
        bool: "BOOLEAN",
        dict: "OBJECT",
        list: "VARIANT"
    }

    # One or more dot-separated unquoted identifiers (e.g. ``schema.table``).
    VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*(\.[a-zA-Z_][a-zA-Z0-9_]*)*$")

    def __init__(self, **kwargs) -> None:
        """
        Create the client and register the Snowflake connector function.

        :param kwargs: Settings forwarded unchanged to the parent constructor:
            * user: Username.
            * host: Hostname.
            * account: Account name.
            * password: Password.
            * warehouse: Warehouse.
            * database: Database.
            * schema: Schema.
            * role: Role.

        To connect using OAuth, the connection string must include the authenticator
        parameter set to oauth and the token parameter set to the oauth_access_token.
        https://docs.snowflake.com/en/user-guide/python-connector-example.html#connecting-with-oauth

            :param authenticator="oauth"
            :param token="oauth_access_token"
        """

        super().__init__(**kwargs)
        self.connect_fcn = snowflake.connector.connect
        self.epoch_to_timestamp_fcn = "TO_TIMESTAMP"

    def test_connection(self, query: Optional[str] = None):
        """
        Test the database connection by executing a simple query.

        :param query: Optional custom query to test. Defaults to querying Snowflake version.
        :return: Whatever the parent implementation returns for the query.
        """

        if not query:
            query = "SELECT current_version();"

        return super().test_connection(query)

    @classmethod
    def get_insert_dml(
        cls,
        table_fqn: str,
        columns: List,
        records: List[Dict],
    ) -> Tuple[str, Tuple]:
        """
        Build a multi-row ``INSERT ... SELECT ... FROM VALUES`` statement.

        Columns holding ``list``/``dict`` values are serialized as JSON text in the
        VALUES clause and wrapped in ``PARSE_JSON`` in the SELECT list.

        :param table_fqn: Target table name (optionally qualified).
        :param columns: Column names to insert, in VALUES order.
        :param records: Rows to insert; each must contain every name in ``columns``.
        :return: ``(query, params)``. ``params`` is always empty because the values are
            inlined as escaped literals. Returns ``("", ())`` when ``records`` is empty.
        :raises KeyError: If a record lacks one of ``columns``.
        """

        if not records:
            return "", tuple()

        columns = list(columns)

        # Validate table name and columns...
        cls.validate_identifier([table_fqn, *columns])

        # Scan every record rather than only the first one: a leading NULL must not
        # hide JSON values further down (PARSE_JSON of NULL yields NULL).
        json_columns = {
            column
            for column in columns
            if any(isinstance(record.get(column), (list, dict)) for record in records)
        }

        # VALUES exposes its columns positionally as Column1, Column2, ...
        select_statement = ", ".join(
            f"PARSE_JSON(Column{pos}) AS {column}"
            if column in json_columns
            else f"Column{pos} AS {column}"
            for pos, column in enumerate(columns, start=1)
        )

        # Explicitly specify columns to insert into (important for tables with
        # auto-increment columns).
        columns_list = ", ".join(columns)
        values = ", ".join(cls._get_values_statement(columns, records))

        # SQL construction is safe: identifiers are validated, values are escaped...
        query = f"""
            INSERT INTO {table_fqn} ({columns_list})
            SELECT {select_statement}
            FROM VALUES {values};"""  # nosec B608

        return query, tuple()

    @classmethod
    @overload
    def get_merge_dml(
        cls,
        target: str,
        columns: List[str],
        pk_ids: List[str],
        records: List[Dict],
        source: Optional[str] = None,
        epoch_column: Optional[str] = None,
    ) -> Tuple[str, Tuple]:
        """ Use this one when the source is a list of records """

    @classmethod
    @overload
    def get_merge_dml(
        cls,
        target: str,
        columns: List[str],
        pk_ids: List[str],
        records: Optional[List[Dict]] = None,
        source: Optional[str] = None,
        epoch_column: Optional[str] = None,
    ) -> Tuple[str, Tuple]:
        """ Use this one when the source is a table """

    @classmethod
    def get_merge_dml(
        cls,
        target: str,
        columns: List[str],
        pk_ids: List[str],
        records: Optional[List[Dict]] = None,
        source: Optional[str] = None,
        epoch_column: Optional[str] = None,
    ) -> Tuple[str, Tuple]:
        """
        Build a ``MERGE`` (upsert) statement against ``target``.

        The source is either an existing table (``source``) or an inline VALUES
        sub-query generated from ``records`` and aliased ``source``.

        :param target: Table to merge into.
        :param columns: Non-key columns to write.
        :param pk_ids: Key columns used in the ``ON`` condition.
        :param records: Rows for the inline source; used only when ``source`` is not given.
        :param source: Name of a source table. When omitted, ``records`` are inlined.
        :param epoch_column: Optional column; when given, matched rows are updated only
            if the source value is greater than the target value.
        :return: ``(query, params)``. ``params`` is always empty. Returns ``("", ())``
            when neither ``source`` nor ``records`` is provided.
        """

        # Validate target table name, pk_ids and columns...
        cls.validate_identifier([target] + pk_ids + columns)
        if epoch_column:
            cls.validate_identifier([epoch_column])

        source_key = source or "source"

        # If source is provided externally, validate it (basic check).
        # Note: This is a simplified check. For complex sources, more validation
        # may be needed.
        if source and source != "source":
            cls.validate_identifier([source_key])

        if not source and not records:
            return "", tuple()

        # Sorted and de-duplicated so the inline VALUES tuples, the source alias
        # column list and the INSERT column list all share one ordering.
        all_columns = pk_ids + columns
        if epoch_column:
            all_columns.append(epoch_column)
        all_columns = sorted(set(all_columns))

        key_columns = set(pk_ids)
        update_columns = [key for key in all_columns if key not in key_columns]

        on_sts = " AND ".join(f"{target}.{key} = {source_key}.{key}" for key in pk_ids)
        matched_and = (
            f"AND {source_key}.{epoch_column} > {target}.{epoch_column} "
            if epoch_column
            else ""
        )

        # With nothing but key columns there is nothing to update; emitting an empty
        # "UPDATE SET" would be invalid SQL, so the matched branch is left out.
        matched_clause = ""
        if update_columns:
            set_statement = ", ".join(f"{key} = {source_key}.{key}" for key in update_columns)
            matched_clause = f"WHEN MATCHED {matched_and}THEN UPDATE SET {set_statement}"

        if source:
            using_clause = source
        else:
            # SQL construction is safe: identifiers are validated, values are escaped...
            values = ", ".join(cls._get_values_statement(all_columns, records))
            source_columns = ", ".join(all_columns)
            using_clause = f"""(
                SELECT * FROM (
                    VALUES {values}
                ) AS source({source_columns})
            ) AS source"""  # nosec B608

        insert_columns = ", ".join(all_columns)
        insert_values = ", ".join(f"{source_key}.{key}" for key in all_columns)

        # SQL construction is safe: identifiers are validated, values are escaped...
        query = f"""
            MERGE INTO {target}
            USING {using_clause}
            ON {on_sts}
            {matched_clause}
            WHEN NOT MATCHED THEN
                INSERT ({insert_columns})
                VALUES ({insert_values});"""  # nosec B608

        return query, tuple()

    @classmethod
    def _get_values_statement(cls, columns: List[str], records: List[Dict]) -> List:
        """
        Generate VALUES clause with properly escaped values.

        :param columns: List of column names.
        :param records: List of record dictionaries.
        :return: One parenthesized, comma-separated tuple string per record.
        :raises KeyError: If a record lacks one of ``columns``.
        """

        values = []
        for record in records:
            literals = []
            for key in columns:
                value = record[key]

                if isinstance(value, (list, dict)):
                    # JSON values are serialized, then escaped like any other string.
                    escaped = cls._escape_string_value(json.dumps(value))
                    literals.append(f"'{escaped}'")

                elif isinstance(value, str):
                    escaped = cls._escape_string_value(value)
                    literals.append(f"'{escaped}'")

                elif value is None:
                    literals.append("NULL")

                elif isinstance(value, (date, datetime)):
                    # Date and datetime values are rendered with str() and quoted.
                    escaped = cls._escape_string_value(str(value))
                    literals.append(f"'{escaped}'")

                else:
                    # Everything else (numbers, booleans) is emitted unquoted via str().
                    literals.append(str(value))

            values.append(f"({', '.join(literals)})")

        return values