Source code for core_db.engines.oracle

# -*- coding: utf-8 -*-

"""
Oracle Database Client Module
===============================

This module provides the OracleClient class for connecting to and interacting
with Oracle databases using the oracledb library.
"""

import json
import re
from datetime import datetime
from typing import Any
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple

import oracledb
from core_mixins.utils import get_batches

from core_db.interfaces.base import DatabaseClientException
from core_db.interfaces.sql_based import ISqlDatabaseClient


class OracleClient(ISqlDatabaseClient):
    """
    Client for Oracle connection...

    ===================================================
    How to use
    ===================================================

    .. code-block:: python

        from core_db.engines.oracle import OracleClient

        with OracleClient(user="...", password="...", dsn=f"{host}:{port}/{service_name}") as client:
            res = client.execute("SELECT * FROM ...")
            for x in client.fetch_all():
                print(x)
    ..
    """

    # Patterns used by ``_convert_value`` to recognise ISO date ("YYYY-MM-DD")
    # and datetime ("YYYY-MM-DD HH:MM:SS") strings. Compiled once at class
    # level instead of being looked up for every bound value.
    _ISO_DATE_PATTERN = re.compile(r"^\d{4}-\d{2}-\d{2}$")
    _ISO_DATETIME_PATTERN = re.compile(r"^\d{4}-\d{2}-\d{2} \d{2}:\d{2}:\d{2}$")

    def __init__(self, **kwargs) -> None:
        """
        Expected -> user, password, dsn...

        More information:
          - https://oracle.github.io/python-oracledb/
          - https://python-oracledb.readthedocs.io/en/latest/index.html
        """
        super().__init__(**kwargs)
        self.connect_fcn = oracledb.connect

    def test_connection(self, query: Optional[str] = None):
        """
        Test the database connection by executing a simple query.

        :param query: Optional custom query to test. Defaults to querying Oracle version.
        :return: Result of the query execution.
        """
        if not query:
            query = 'SELECT * FROM "V$VERSION"'

        return super().test_connection(query)

    @staticmethod
    def _convert_value(value: Any) -> Any:
        """
        Convert a Python value into something the Oracle driver can bind.

        - ``dict`` / ``list`` (including subclasses) -> JSON string.
        - ``"YYYY-MM-DD"`` strings -> ``datetime.date``.
        - ``"YYYY-MM-DD HH:MM:SS"`` strings -> ``datetime.datetime``.
        - Anything else (including strings shaped like a date but not a valid
          calendar date, e.g. ``"2024-13-45"``) is returned unchanged.

        :param value: The value to convert.
        :return: Converted value.
        """
        if isinstance(value, (dict, list)):
            return json.dumps(value)

        if isinstance(value, str):
            try:
                if OracleClient._ISO_DATE_PATTERN.match(value):
                    return datetime.strptime(value, "%Y-%m-%d").date()
                if OracleClient._ISO_DATETIME_PATTERN.match(value):
                    return datetime.strptime(value, "%Y-%m-%d %H:%M:%S")
            except ValueError:
                # Looks like a date but isn't a real one: keep the raw text.
                return value

        return value

    def _execute(self, query: Any, **kwargs):
        """
        Override execute to handle Oracle's parameter format requirements,
        because Oracle's oracledb driver expects parameters as a list passed
        as the second positional argument, not as a keyword argument.

        Positional (tuple/list) parameters are converted with ``_convert_value``;
        any other truthy ``params`` object is passed through unchanged.

        :raises DatabaseClientException: If there is no active cursor.
        """
        if not self.cursor:
            raise DatabaseClientException("No active cursor!")

        params = kwargs.pop("params", None)
        if params:
            if isinstance(params, (tuple, list)):
                params = [self._convert_value(p) for p in params]
            return self.cursor.execute(query, params, **kwargs)

        return self.cursor.execute(query, **kwargs)

    @staticmethod
    def _flatten_record_params(columns: List[str], records: List[Dict]) -> Tuple:
        """
        Flatten ``records`` into one positional bind tuple.

        Values are emitted row by row, column by column, matching the
        ``:1 .. :N`` placeholder numbering used by the DML builders.
        ``dict``/``list`` values are JSON-serialized.
        """
        return tuple(
            json.dumps(record[col]) if isinstance(record[col], (dict, list)) else record[col]
            for record in records
            for col in columns
        )

    def insert_records(
        self,
        table_fqn: str,
        columns: List[str],
        records: List[Dict],
        records_per_request: int = 500,
    ) -> int:
        """
        Insert records using Oracle's executemany for better performance
        and proper type handling.

        :param table_fqn: Table's fully qualified name (FQN).
        :param columns: List of column names to insert into.
        :param records: List of dictionaries representing records to insert.
        :param records_per_request: Number of records to insert per batch.
        :return: Total number of inserted records (sum of ``cursor.rowcount`` per batch).
        :raises DatabaseClientException: If insertion fails.
        """
        if not records:
            return 0

        if not self.cxn:
            raise DatabaseClientException("There is not an active connection!")

        try:
            # Ensure cursor exists...
            if not self.cursor:
                self.cursor = self.cxn.cursor()

            self.validate_identifier(columns)

            # Building single-row INSERT with positional placeholders...
            placeholders = ", ".join(f":{i + 1}" for i in range(len(columns)))

            # SQL construction is safe: columns are validated, values use placeholders...
            query = f"INSERT INTO {table_fqn} ({', '.join(columns)}) VALUES ({placeholders})"  # nosec B608

            total = 0
            for chunk in get_batches(records, records_per_request):
                # executemany expects one list of bind values per row, so reuse
                # the same conversion rules applied by ``_execute``.
                params_list = [
                    [self._convert_value(record[col]) for col in columns]
                    for record in chunk
                ]

                self.cursor.executemany(query, params_list)
                total += self.cursor.rowcount

            return total

        except Exception as error:
            raise DatabaseClientException(error) from error

    @classmethod
    def _get_conditional_statements(
        cls,
        conditionals: Optional[List[Dict]] = None,
    ) -> Tuple[List, List]:
        """
        Helper function to generate the conditions and params and
        reuse it into other implementations. Override if required by
        a specific engine.

        Each conditional dict becomes one ``(k1 = :n AND k2 = :n+1 ...)``
        group; placeholders are numbered consecutively across all groups
        starting at ``:1``, and ``params`` lists the values in that order.
        """
        condition_parts = []
        params = []
        param_counter = 1

        for conditional in conditionals or []:
            keys = list(conditional.keys())
            cls.validate_identifier(keys)

            condition_part = " AND ".join(f"{key} = :{param_counter + i}" for i, key in enumerate(keys))
            condition_parts.append(f"({condition_part})")
            params.extend(conditional[key] for key in keys)
            param_counter += len(keys)

        return condition_parts, params

    @classmethod
    def get_insert_dml(
        cls,
        table_fqn: str,
        columns: List[str],
        records: List[Dict],
    ) -> Tuple[str, Tuple]:
        """
        Generate a parameterized INSERT statement for Oracle. Uses Oracle's
        positional bind syntax (:1, :2, etc.) and parameter binding to
        prevent SQL injection attacks.

        :param table_fqn: Table's fully qualified name (FQN).
        :param columns: List of column names to insert into.
        :param records: List of dictionaries representing records to insert.
        :return: Tuple of (query string with placeholders, flattened parameter tuple).
        :raises ValueError: If column names contain invalid characters.
        """
        if not records:
            return "", tuple()

        cls.validate_identifier(columns)
        column_list = ", ".join(columns)
        width = len(columns)

        # Building multi-row INSERT ALL statement with Oracle syntax; each row
        # gets its own contiguous range of placeholders...
        insert_statements = []
        for row_idx in range(len(records)):
            offset = row_idx * width
            placeholders = ", ".join(f":{offset + j + 1}" for j in range(width))
            insert_statements.append(f"INTO {table_fqn} ({column_list}) VALUES ({placeholders})")

        # SQL construction is safe: columns are validated, values use placeholders...
        query = "INSERT ALL\n " + "\n ".join(insert_statements) + "\nSELECT 1 FROM DUAL"  # nosec B608
        return query, cls._flatten_record_params(columns, records)

    @classmethod
    def get_merge_dml(
        cls,
        table_fqn: str,
        pk_ids: List[str],
        columns: List[str],
        records: List[Dict],
    ) -> Tuple[str, Tuple]:
        """
        Generate parameterized MERGE statement for Oracle. Uses parameter
        binding to prevent SQL injection attacks.

        The source rows are built as ``SELECT :n AS col ... FROM DUAL``
        queries joined with ``UNION ALL``. Matched rows update every non-key
        column; when all columns are keys the ``WHEN MATCHED`` branch is
        omitted (an empty ``UPDATE SET`` is invalid SQL), so the statement
        only inserts missing rows.

        :param table_fqn: Table's fully qualified name.
        :param pk_ids: List of primary key column names (must not be empty).
        :param columns: List of column names.
        :param records: List of dictionaries representing records.
        :return: Tuple of (query string with placeholders, flattened parameter tuple).
        :raises ValueError: If column names contain invalid characters or ``pk_ids`` is empty.
        """
        if not records:
            return "", tuple()

        cls.validate_identifier(columns + pk_ids)

        if not pk_ids:
            # An empty ON () clause can never be valid Oracle SQL.
            raise ValueError("At least one primary key column is required to build a MERGE statement!")

        # Building the USING clause: one aliased SELECT per record, each with
        # its own contiguous range of placeholders, glued with UNION ALL...
        width = len(columns)
        source_selects = []
        for row_idx in range(len(records)):
            offset = row_idx * width
            aliased = ", ".join(f":{offset + j + 1} AS {col}" for j, col in enumerate(columns))
            source_selects.append(f"SELECT {aliased} FROM DUAL")  # nosec B608

        using_clause = "(" + " UNION ALL ".join(source_selects) + ")"

        on_conditions = " AND ".join(f"target.{pk} = source.{pk}" for pk in pk_ids)
        update_columns = [col for col in columns if col not in pk_ids]
        insert_columns = ", ".join(columns)
        insert_values = ", ".join(f"source.{col}" for col in columns)

        # SQL construction is safe: columns are validated, values use placeholders...
        clauses = [
            f"MERGE INTO {table_fqn} target",
            f"USING {using_clause} source",
            f"ON ({on_conditions})",
        ]

        if update_columns:
            set_statement = ", ".join(f"target.{col} = source.{col}" for col in update_columns)
            clauses.append(f"WHEN MATCHED THEN UPDATE SET {set_statement}")  # nosec B608

        clauses.append(f"WHEN NOT MATCHED THEN INSERT ({insert_columns}) VALUES ({insert_values})")  # nosec B608
        query = "\n".join(clauses)

        return query, cls._flatten_record_params(columns, records)