Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 83 additions & 6 deletions mssql_python/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,23 +51,44 @@
INFO_TYPE_STRING_THRESHOLD: int = 10000

# UTF-16 encoding variants that should use SQL_WCHAR by default
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16", "utf-16le", "utf-16be"])
UTF16_ENCODINGS: frozenset[str] = frozenset(["utf-16le", "utf-16be"])


def _validate_encoding(encoding: str) -> bool:
"""
Cached encoding validation using codecs.lookup().

Validate encoding name for security and correctness.

This function performs two-layer validation:
1. Security check: Ensures only safe characters are in the encoding name
2. Codec check: Verifies it's a valid Python codec

Args:
encoding (str): The encoding name to validate.

Returns:
bool: True if encoding is valid, False otherwise.
bool: True if encoding is valid and safe, False otherwise.

Note:
Uses LRU cache to avoid repeated expensive codecs.lookup() calls.
Cache size is limited to 128 entries which should cover most use cases.
Rejects encodings with:
- Empty or too long names (>100 chars)
- Suspicious characters (only alphanumeric, hyphen, underscore, dot allowed)
- Invalid Python codecs
"""
# Type check: encoding must be a string
if not isinstance(encoding, str):
return False

# Security validation: Check length and characters
if not encoding or len(encoding) > 100:
return False

# Only allow safe characters in encoding names
# Valid codec names contain: letters, numbers, hyphens, underscores, dots
for char in encoding:
if not (char.isalnum() or char in ('-', '_', '.')):
return False

# Verify it's a valid Python codec
try:
codecs.lookup(encoding)
return True
Expand Down Expand Up @@ -400,6 +421,19 @@ def setencoding(
# Normalize encoding to casefold for more robust Unicode handling
encoding = encoding.casefold()

# Explicitly reject 'utf-16' with BOM - require explicit endianness
if encoding == 'utf-16' and ctype == ConstantsDDBC.SQL_WCHAR.value:
error_msg = (
"The 'utf-16' codec includes a Byte Order Mark (BOM) which is incompatible with SQL_WCHAR. "
"Use 'utf-16le' (little-endian) or 'utf-16be' (big-endian) instead. "
"SQL Server's NVARCHAR/NCHAR types expect UTF-16LE without BOM."
)
log('error', "Attempted to use 'utf-16' with BOM for SQL_WCHAR")
raise ProgrammingError(
driver_error=error_msg,
ddbc_error=error_msg,
)

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
Expand All @@ -424,6 +458,20 @@ def setencoding(
),
)

# Enforce UTF-16 encoding restriction for SQL_WCHAR
if ctype == ConstantsDDBC.SQL_WCHAR.value and encoding not in UTF16_ENCODINGS:
error_msg = (
f"SQL_WCHAR only supports UTF-16 encodings (utf-16, utf-16le, utf-16be). "
f"Encoding '{encoding}' is not compatible with SQL_WCHAR. "
f"Either use a UTF-16 encoding, or use SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) instead."
)
log('error', "Invalid encoding/ctype combination: %s with SQL_WCHAR",
sanitize_user_input(encoding))
raise ProgrammingError(
driver_error=error_msg,
ddbc_error=error_msg,
)

# Store the encoding settings
self._encoding_settings = {"encoding": encoding, "ctype": ctype}

Expand Down Expand Up @@ -543,13 +591,42 @@ def setdecoding(
# Normalize encoding to lowercase for consistency
encoding = encoding.lower()

# Enforce UTF-16 encoding restriction for SQL_WCHAR and SQL_WMETADATA
if (sqltype == ConstantsDDBC.SQL_WCHAR.value or sqltype == SQL_WMETADATA) and encoding not in UTF16_ENCODINGS:
sqltype_name = "SQL_WCHAR" if sqltype == ConstantsDDBC.SQL_WCHAR.value else "SQL_WMETADATA"
error_msg = (
f"{sqltype_name} only supports UTF-16 encodings (utf-16, utf-16le, utf-16be). "
f"Encoding '{encoding}' is not compatible with {sqltype_name}. "
f"Either use a UTF-16 encoding, or use SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) "
f"for the ctype parameter."
)
log('error', "Invalid encoding for %s: %s", sqltype_name, sanitize_user_input(encoding))
raise ProgrammingError(
driver_error=error_msg,
ddbc_error=error_msg,
)

# Set default ctype based on encoding if not provided
if ctype is None:
if encoding in UTF16_ENCODINGS:
ctype = ConstantsDDBC.SQL_WCHAR.value
else:
ctype = ConstantsDDBC.SQL_CHAR.value

# Additional validation: if user explicitly sets ctype to SQL_WCHAR but encoding is not UTF-16
if ctype == ConstantsDDBC.SQL_WCHAR.value and encoding not in UTF16_ENCODINGS:
error_msg = (
f"SQL_WCHAR ctype only supports UTF-16 encodings (utf-16, utf-16le, utf-16be). "
f"Encoding '{encoding}' is not compatible with SQL_WCHAR ctype. "
f"Either use a UTF-16 encoding, or use SQL_CHAR ({ConstantsDDBC.SQL_CHAR.value}) "
f"for the ctype parameter."
)
log('error', "Invalid encoding for SQL_WCHAR ctype: %s", sanitize_user_input(encoding))
raise ProgrammingError(
driver_error=error_msg,
ddbc_error=error_msg,
)

# Validate ctype
valid_ctypes = [ConstantsDDBC.SQL_CHAR.value, ConstantsDDBC.SQL_WCHAR.value]
if ctype not in valid_ctypes:
Expand Down
82 changes: 73 additions & 9 deletions mssql_python/cursor.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from mssql_python.constants import ConstantsDDBC as ddbc_sql_const, SQLTypes
from mssql_python.helpers import check_error, log
from mssql_python import ddbc_bindings
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError
from mssql_python.exceptions import InterfaceError, NotSupportedError, ProgrammingError, OperationalError, DatabaseError
from mssql_python.row import Row
from mssql_python import get_settings

Expand Down Expand Up @@ -287,6 +287,51 @@ def _get_numeric_data(self, param: decimal.Decimal) -> Any:

numeric_data.val = bytes(byte_array)
return numeric_data

def _get_encoding_settings(self):
"""
Get the encoding settings from the connection.

Returns:
dict: A dictionary with 'encoding' and 'ctype' keys, or default settings if not available
"""
if hasattr(self._connection, 'getencoding'):
try:
return self._connection.getencoding()
except (OperationalError, DatabaseError) as db_error:
# Only catch database-related errors, not programming errors
log('warning', f"Failed to get encoding settings from connection due to database error: {db_error}")
return {
'encoding': 'utf-16le',
'ctype': ddbc_sql_const.SQL_WCHAR.value
}

# Return default encoding settings if getencoding is not available
return {
'encoding': 'utf-16le',
'ctype': ddbc_sql_const.SQL_WCHAR.value
}

def _get_decoding_settings(self, sql_type):
"""
Get decoding settings for a specific SQL type.

Args:
sql_type: SQL type constant (SQL_CHAR, SQL_WCHAR, etc.)

Returns:
Dictionary containing the decoding settings.
"""
try:
# Get decoding settings from connection for this SQL type
return self._connection.getdecoding(sql_type)
except (OperationalError, DatabaseError) as db_error:
# Only handle expected database-related errors
log('warning', f"Failed to get decoding settings for SQL type {sql_type} due to database error: {db_error}")
if sql_type == ddbc_sql_const.SQL_WCHAR.value:
return {'encoding': 'utf-16le', 'ctype': ddbc_sql_const.SQL_WCHAR.value}
else:
return {'encoding': 'utf-8', 'ctype': ddbc_sql_const.SQL_CHAR.value}

def _map_sql_type( # pylint: disable=too-many-arguments,too-many-positional-arguments,too-many-locals,too-many-return-statements,too-many-branches
self,
Expand Down Expand Up @@ -1028,6 +1073,9 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state
# Clear any previous messages
self.messages = []

# Getting encoding setting
encoding_settings = self._get_encoding_settings()

# Apply timeout if set (non-zero)
if self._timeout > 0:
try:
Expand Down Expand Up @@ -1100,6 +1148,7 @@ def execute( # pylint: disable=too-many-locals,too-many-branches,too-many-state
parameters_type,
self.is_stmt_prepared,
use_prepare,
encoding_settings
)
# Check return code
try:
Expand Down Expand Up @@ -1897,9 +1946,10 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
processed_parameters.append(processed_row)

# Now transpose the processed parameters
columnwise_params, row_count = self._transpose_rowwise_to_columnwise(
processed_parameters
)
columnwise_params, row_count = self._transpose_rowwise_to_columnwise(processed_parameters)

# Get encoding settings
encoding_settings = self._get_encoding_settings()

# Add debug logging
log(
Expand All @@ -1913,7 +1963,12 @@ def executemany( # pylint: disable=too-many-locals,too-many-branches,too-many-s
)

ret = ddbc_bindings.SQLExecuteMany(
self.hstmt, operation, columnwise_params, parameters_type, row_count
self.hstmt,
operation,
columnwise_params,
parameters_type,
row_count,
encoding_settings
)

# Capture any diagnostic messages after execution
Expand Down Expand Up @@ -1945,11 +2000,14 @@ def fetchone(self) -> Union[None, Row]:
"""
self._check_closed() # Check if the cursor is closed

char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
row_data = []
try:
ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data)

ret = ddbc_bindings.DDBCSQLFetchOne(self.hstmt, row_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))
if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))

Expand Down Expand Up @@ -1997,10 +2055,13 @@ def fetchmany(self, size: Optional[int] = None) -> List[Row]:
if size <= 0:
return []

char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
rows_data = []
try:
_ = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size)
ret = ddbc_bindings.DDBCSQLFetchMany(self.hstmt, rows_data, size, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down Expand Up @@ -2039,10 +2100,13 @@ def fetchall(self) -> List[Row]:
if not self._has_result_set and self.description:
self._reset_rownumber()

char_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_CHAR.value)
wchar_decoding = self._get_decoding_settings(ddbc_sql_const.SQL_WCHAR.value)

# Fetch raw data
rows_data = []
try:
_ = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data)
ret = ddbc_bindings.DDBCSQLFetchAll(self.hstmt, rows_data, char_decoding.get('encoding', 'utf-8'), wchar_decoding.get('encoding', 'utf-16le'))

if self.hstmt:
self.messages.extend(ddbc_bindings.DDBCSQLGetAllDiagRecords(self.hstmt))
Expand Down
Loading
Loading