Source code for core_db.interfaces.sql_based

# -*- coding: utf-8 -*-

import json
import re
from abc import ABC, abstractmethod
from typing import Any
from typing import Dict
from typing import Iterable
from typing import Iterator
from typing import List
from typing import Optional
from typing import Tuple
from typing import overload

from core_mixins.utils import get_batches

from core_db.interfaces.base import DatabaseClientException
from .base import IDatabaseClient


class ISqlDatabaseClient(IDatabaseClient, ABC):
    """
    Abstract base class for SQL-based database clients.

    This class extends IDatabaseClient to provide SQL-specific functionality
    including parameterized query execution, CRUD operations, and batch data
    manipulation. It implements security measures to prevent SQL injection and
    provides standardized methods for common database operations.

    Key Features:
    -------------
    - **Parameterized queries**: All DML methods bind values through
      placeholders to prevent SQL injection.
    - **Batch operations**: Efficient batch inserts with configurable chunk sizes.
    - **Type mapping**: Python-to-SQL type conversion for DDL generation.
    - **Identifier validation**: Column names (including primary key, unique
      and conditional columns) are validated against ``VALID_IDENTIFIER``.
    - **Context manager support**: Automatic commit and connection cleanup.

    .. warning::
        Table fully qualified names (``table_fqn``) are interpolated into the
        generated SQL as-is and are NOT validated. They must come from trusted
        code, never from user input.

    Usage:
    ------
    .. code-block:: python

        class MyDatabaseClient(ISqlDatabaseClient):
            PLACEHOLDER = "?"  # Override for database-specific placeholder

            def __init__(self, **kwargs):
                super().__init__(**kwargs)
                self.connect_fcn = my_driver.connect

            @classmethod
            def get_merge_dml(cls, table_fqn, pk_ids, columns, records):
                # Implement database-specific MERGE/UPSERT logic
                pass

        # Use the client
        with MyDatabaseClient(host="localhost", database="mydb") as client:
            # Insert records
            client.insert_records(
                table_fqn="users",
                columns=["name", "email"],
                records=[{"name": "Alice", "email": "alice@example.com"}]
            )

            # Query data
            client.select("users", columns=["name", "email"])
            for record in client.fetch_records():
                print(record)

    See Also:
    ---------
    - core_db.engines: Concrete implementations for specific databases
    """

    # Mapper for python types to database types...
    TYPE_MAPPER = {
        int: "INTEGER",
        float: "DOUBLE",
        str: "TEXT",
        bool: "BOOLEAN",
        dict: "JSON",
        list: "JSON",
    }

    # Pattern every column name must fully match, preventing SQL injection.
    # The identifier must start with a letter or underscore, followed by
    # alphanumerics/underscores. Dots are NOT accepted...
    VALID_IDENTIFIER = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")

    # Each database engine could have its own
    # symbol, override if required...
    PLACEHOLDER = "%s"

    def __init__(self, **kwargs):
        """
        Initialize SQL database client with connection parameters.

        :param kwargs:
            Database-specific connection parameters (e.g., host, port,
            database, user, password).
        """
        super().__init__(**kwargs)

        # Function used by the Database Engine
        # to convert to timestamp...
        self.epoch_to_timestamp_fcn = None

    def test_connection(self, query: Any = None):
        """
        Test the database connection by executing a version query.

        :param query: Optional custom query to test connection. Defaults to version query.
        :return: Query execution result.
        :raises DatabaseClientException: If connection test fails.
        """
        try:
            return self.execute(query or "SELECT version() AS version;")
        except Exception as error:
            raise DatabaseClientException(error) from error

    def execute(self, query: Any, **kwargs):
        """
        Execute a SQL query, lazily creating the cursor on first use.

        :param query: SQL query string to execute.
        :param kwargs: Additional keyword arguments forwarded to ``_execute``.
        :return: Cursor execution result.
        :raises DatabaseClientException: If there is no active connection or execution fails.
        """
        if not self.cxn:
            raise DatabaseClientException("There is not an active connection!")

        try:
            if not self.cursor:
                self.cursor = self.cxn.cursor()

            return self._execute(query, **kwargs)

        except Exception as error:
            raise DatabaseClientException(error) from error

    def _execute(self, query: Any, **kwargs):
        """
        Internal method for executing queries with database-specific parameter handling.

        Override this method in subclasses if the database driver requires specific
        parameter passing conventions (e.g., positional vs keyword arguments).

        :param query: SQL query string to execute.
        :param kwargs: Additional keyword arguments including 'params' for parameter binding.
        :return: Cursor execution result.
        :raises DatabaseClientException: If there is no active cursor.
        """
        if not self.cursor:
            raise DatabaseClientException("No active cursor!")

        return self.cursor.execute(query, **kwargs)

    def commit(self) -> None:
        """
        Commit the current transaction to persist changes.

        :raises DatabaseClientException: If no active connection exists.
        """
        if not self.cxn:
            raise DatabaseClientException("No active connection!")

        self.cxn.commit()

    def select(self, table_fqn: str, columns: Optional[List[str]] = None):
        """
        Execute a SELECT query on the specified table.

        :param table_fqn: Table's fully qualified name (trusted, not validated).
        :param columns: List of column names to select. If None, selects all columns (*).
        :return: Cursor execution result.
        :raises DatabaseClientException:
            If column names contain invalid characters or execution fails.
            Validation happens inside ``execute``'s argument evaluation, so the
            ValueError propagates unwrapped.
        """
        return self.execute(self.get_select_ddl(table_fqn, columns))

    @classmethod
    def validate_identifier(cls, identifiers: Iterable[str]) -> None:
        """
        Validate column names to prevent SQL injection attacks.

        Each identifier must match ``VALID_IDENTIFIER`` in its entirety:
          - Must start with a letter or underscore.
          - Can contain only alphanumeric characters and underscores.

        ``fullmatch`` is used on purpose. With ``match``, the ``$`` anchor also
        matches just before a trailing newline, so ``"col\\n"`` would slip
        through.

        :param identifiers: Iterable of identifiers (e.g. column names) to validate.
        :raises ValueError: If any identifier contains invalid characters.
        """
        for identifier in identifiers:
            if not cls.VALID_IDENTIFIER.fullmatch(identifier):
                raise ValueError(
                    f"Invalid identifier: '{identifier}'. "
                    "Identifiers must start with a letter or underscore and contain only "
                    "alphanumeric characters and underscores."
                )

    @staticmethod
    def _escape_string_value(value: str) -> str:
        """
        Escape string values to prevent SQL injection.

        Escapes single quotes by doubling them (SQL standard).

        :param value: The string value to escape.
        :return: Escaped string value.
        """
        # Replace single quotes with double single quotes (SQL standard escaping)
        return value.replace("'", "''")

    @classmethod
    def get_select_ddl(
        cls,
        table_fqn: str,
        columns: Optional[List[str]] = None,
    ) -> str:
        """
        Returns the statement for a select.

        :param table_fqn: Table's fully qualified name (trusted, not validated).
        :param columns: List of column names to select. If None, selects all columns (*).
        :return: SELECT SQL statement.
        :raises ValueError: If column names contain invalid characters (potential SQL injection).
        """
        if columns:
            cls.validate_identifier(columns)
            column_list = ", ".join(columns)
        else:
            column_list = "*"

        # Columns are validated; table_fqn is interpolated as-is and must
        # come from trusted code...
        return f"SELECT {column_list} FROM {table_fqn}"  # nosec B608

    def columns(self):
        """
        Get column names from the current cursor.

        :return: Lower-cased column names, or an empty list if there is no cursor/description.
        """
        if self.cursor and self.cursor.description:
            return [x[0].lower() for x in self.cursor.description]  # type: ignore[union-attr]

        return []

    def fetch_record(self) -> Dict[str, Any]:
        """
        Fetch a single record as a dictionary with column names as keys.

        :return: Dictionary mapping column names to row values, or an empty dict if no record.
        """
        res = self.fetch_one()
        return dict(zip(self.columns(), res)) if res else {}

    def fetch_one(self) -> Optional[Tuple]:
        """
        Fetch a single record as a tuple.

        :return: Tuple containing row values, or None when there are no more rows.
        :raises DatabaseClientException: If there is no active cursor.
        """
        if not self.cursor:
            raise DatabaseClientException("No active cursor!")

        row = self.cursor.fetchone()
        return tuple(row) if row is not None else row

    def fetch_records(self) -> Iterator[Dict[str, Any]]:
        """
        Fetch all records as an iterator of dictionaries.

        Converts fetchall rows into dictionaries with column names as keys.

        :return: Iterator yielding dictionaries with column names as keys.
        """
        headers = self.columns()
        for row in self.fetch_all():
            yield dict(zip(headers, row))

    def fetch_all(self) -> Iterator[Tuple]:
        """
        Fetch all records as an iterator.

        :return: Iterator yielding the rows exactly as returned by the cursor's ``fetchall()``.
        """
        if self.cursor:
            rows = self.cursor.fetchall()
            if rows:
                for row in rows:  # type: ignore[union-attr]
                    yield row

    @classmethod
    def get_create_table_ddl(
        cls,
        table_fqn: str,
        columns: List[Tuple[str, Any]],
        temporal: bool = False,
        primary_keys: Optional[List[str]] = None,
        unique_columns: Optional[List[str]] = None,
        not_null_columns: Optional[List[str]] = None,
    ) -> str:
        """
        Generate the SQL CREATE TABLE statement.

        :param table_fqn: Table's fully qualified name (trusted, not validated).
        :param columns:
            List of tuples defining the column name and Python type. The type
            is looked up in ``TYPE_MAPPER`` and falls back to VARCHAR.
        :param temporal: Whether to create a temporary table. Defaults to False.
        :param primary_keys: Column names to include in the PRIMARY KEY constraint.
        :param unique_columns: Column names to include in the UNIQUE constraint.
        :param not_null_columns: Column names that should have a NOT NULL constraint.
        :return: The CREATE TABLE SQL statement.
        :raises ValueError: If any column/constraint name contains invalid characters.
        """
        col_names = [name for name, _ in columns]
        cls.validate_identifier(col_names)

        # Constraint columns are interpolated into the DDL too, so they need
        # the same validation as the column definitions...
        if primary_keys:
            cls.validate_identifier(primary_keys)
        if unique_columns:
            cls.validate_identifier(unique_columns)

        not_null_set = set(not_null_columns or [])
        parts = [
            f"{name} {cls.TYPE_MAPPER.get(type_, 'VARCHAR')}{' NOT NULL' if name in not_null_set else ''}"
            for name, type_ in columns
        ]

        if primary_keys:
            parts.append(f"PRIMARY KEY ({', '.join(primary_keys)})")

        if unique_columns:
            parts.append(f"UNIQUE ({', '.join(unique_columns)})")

        columns_def = ", ".join(parts)
        return f"CREATE{' TEMPORARY' if temporal else ''} TABLE {table_fqn} ({columns_def});"

    def insert_records(
        self,
        table_fqn: str,
        columns: List[str],
        records: List[Dict],
        records_per_request: int = 500,
    ) -> int:
        """
        Insert a batch of records into a table using parameterized queries.

        Splits ``records`` into chunks of ``records_per_request`` so that each
        INSERT statement stays bounded in size.

        :param table_fqn: Table's fully qualified name (FQN).
        :param columns: List of column names to insert into.
        :param records: List of dictionaries representing records to insert.
        :param records_per_request: Number of records to insert per batch. Defaults to 500.
        :return:
            Total number of inserted records. For each chunk, the cursor's
            ``rowcount`` is used. When the driver reports it as unknown
            (negative or None), the chunk size is counted instead.
        :raises DatabaseClientException: If insertion fails.
        """
        if not records:
            return 0

        total = 0
        try:
            for chunk_ in get_batches(records, records_per_request):
                query, params = self.get_insert_dml(table_fqn, columns, chunk_)
                self.execute(query, params=params)

                if self.cursor:
                    # DB-API allows rowcount == -1 when the count cannot be
                    # determined. Adding it would make the total negative...
                    affected = self.cursor.rowcount
                    if affected is None or affected < 0:
                        affected = len(chunk_)
                    total += affected

        except Exception as error:
            raise DatabaseClientException(error) from error

        return total

    @classmethod
    def get_insert_dml(
        cls,
        table_fqn: str,
        columns: List[str],
        records: List[Dict],
    ) -> Tuple[str, tuple]:
        """
        Generate a parameterized INSERT statement with multi-row VALUES.

        Uses parameter binding to prevent SQL injection attacks. Values that are
        dicts or lists (including subclasses) are serialized with ``json.dumps``.

        :param table_fqn: Table's fully qualified name (FQN, trusted, not validated).
        :param columns: List of column names to insert into.
        :param records: List of dictionaries representing records to insert.
        :return: Tuple of (query string with placeholders, flattened parameter tuple).
        :raises ValueError: If column names contain invalid characters.
        :raises KeyError: If a record is missing one of ``columns``.
        """
        if not records:
            return "", tuple()

        cls.validate_identifier(columns)
        placeholders = ", ".join([cls.PLACEHOLDER] * len(columns))
        values_rows = ", ".join([f"({placeholders})" for _ in records])

        # Columns are validated and values use placeholders; table_fqn is
        # interpolated as-is and must come from trusted code...
        query = f"INSERT INTO {table_fqn} ({', '.join(columns)}) VALUES {values_rows}"  # nosec B608

        # Extract and flatten parameters in row-major, column order...
        params: List[Any] = []
        for record in records:
            for col in columns:
                value = record[col]
                params.append(json.dumps(value) if isinstance(value, (dict, list)) else value)

        return query, tuple(params)

    @classmethod
    @overload
    def get_delete_dml(
        cls,
        table_fqn: str,
        *,
        pk_id: Optional[str] = None,
        ids: Optional[List] = None,
    ) -> Tuple[str, Tuple]:
        """Generate DELETE statement with primary key IN clause."""

    @classmethod
    @overload
    def get_delete_dml(
        cls,
        table_fqn: str,
        *,
        pk_id: Optional[str] = None,
        conditionals: Optional[List[Dict]] = None,
    ) -> Tuple[str, Tuple]:
        """Generate DELETE statement with multiple conditional clauses."""

    @classmethod
    def get_delete_dml(
        cls,
        table_fqn: str,
        *,
        pk_id: Optional[str] = None,
        ids: Optional[List] = None,
        conditionals: Optional[List[Dict]] = None,
    ) -> Tuple[str, Tuple]:
        """
        Generate a parameterized DELETE statement with placeholders.

        Uses parameter binding to prevent SQL injection attacks.

        When ``pk_id`` is given, rows are deleted with ``pk_id IN (...)``. The
        values come from ``conditionals[i][pk_id]`` if ``conditionals`` is
        given, otherwise from ``ids``. Without ``pk_id``, each conditional dict
        becomes an AND-group, and the groups are OR-ed together.

        :param table_fqn: Table's fully qualified name (trusted, not validated).
        :param pk_id: Primary key column name for IN clause deletion.
        :param ids: List of ID values to delete (used with pk_id).
        :param conditionals: List of dictionaries with conditional criteria for WHERE clause.
        :return:
            Tuple of (query string with placeholders, tuple of parameter
            values). Returns ``("", ())`` when there is nothing to delete.
        :raises ValueError: If ``pk_id`` or a conditional key is not a valid identifier.
        """
        if pk_id:
            if conditionals:
                values = tuple(rec[pk_id] for rec in conditionals)
            else:
                values = tuple(ids) if ids else tuple()

            if not values:
                return "", tuple()

            # pk_id is interpolated into the query, so it must be validated...
            cls.validate_identifier([pk_id])
            placeholders = ", ".join([cls.PLACEHOLDER] * len(values))

            # pk_id is validated and values use placeholders...
            query = f"DELETE FROM {table_fqn} WHERE {pk_id} IN ({placeholders})"  # nosec B608
            return query, values

        if not conditionals:
            return "", tuple()

        condition_parts, params = cls._get_conditional_statements(conditionals)

        # Columns are validated in _get_conditional_statements and values
        # use placeholders...
        query = f"DELETE FROM {table_fqn} WHERE {' OR '.join(condition_parts)}"  # nosec B608
        return query, tuple(params)

    @classmethod
    def _get_conditional_statements(
        cls,
        conditionals: Optional[List[Dict]] = None,
    ) -> Tuple[List, List]:
        """
        Generate parameterized WHERE clause components from conditional dictionaries.

        Each dictionary in conditionals represents an OR condition, with keys as column names
        and values as comparison values. Keys within a dictionary are combined with AND.

        :param conditionals: List of dictionaries with column:value pairs.
        :return: Tuple of (list of condition strings, list of parameter values).
        :raises ValueError: If any key is not a valid identifier.

        Example (shown with ``?``; the actual marker is ``cls.PLACEHOLDER``):
        --------
        >>> conditionals = [{"name": "Alice", "age": 30}, {"status": "active"}]
        >>> # Generates: WHERE (name = ? AND age = ?) OR (status = ?)
        """
        condition_parts = []
        params = []

        if conditionals:
            for conditional in conditionals:
                keys = list(conditional.keys())
                cls.validate_identifier(keys)

                condition_part = " AND ".join([f"{key} = {cls.PLACEHOLDER}" for key in keys])
                condition_parts.append(f"({condition_part})")
                params.extend([conditional[key] for key in keys])

        return condition_parts, params

    @classmethod
    @abstractmethod
    def get_merge_dml(cls, *args, **kwargs) -> Tuple[str, Tuple]:
        """
        Generate the MERGE/UPSERT statement.

        This is an abstract method that must be implemented by concrete classes,
        as each database engine may require specific syntax for merge operations.

        :param args: Positional arguments specific to the implementation.
        :param kwargs: Keyword arguments specific to the implementation.
        :return: Tuple of (MERGE/UPSERT statement string, tuple of parameters).
        """

    def close(self) -> None:
        """
        Close the database connection after committing pending changes.

        Any pending transaction is committed first. The parent ``close()`` runs
        even if the commit fails, so the connection is not leaked. The commit
        error is still propagated afterwards.
        """
        try:
            if self.cxn:
                self.commit()
        finally:
            super().close()