1 change: 1 addition & 0 deletions .gitignore
@@ -159,6 +159,7 @@ node_modules
*.cdk.staging
*cdk.out
*cdk.context.json
*poetry.lock

# ruff
.ruff_cache/
10 changes: 9 additions & 1 deletion awswrangler/athena/_executions.py
@@ -47,6 +47,7 @@ def start_query_execution(
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
data_source: str | None = None,
wait: bool = False,
retreive_workgroup_config: bool = True,
) -> str | dict[str, Any]:
"""Start a SQL Query against AWS Athena.

@@ -114,6 +115,11 @@ def start_query_execution(
Data Source / Catalog name. If None, 'AwsDataCatalog' will be used by default.
wait
Indicates whether to wait for the query to finish and return a dictionary with the query execution response.
retreive_workgroup_config
Contributor:
I think this flag might be redundant. We should let the user override s3_output, encryption, and kms key (which we already do), otherwise fall back to the default workgroup settings.

Author:
Your suggestion is definitely an option - it just means we would never fetch the settings from the AWS API. Personally I'm fine with that, and I'm ready to change the PR accordingly; let me know if that works for you.

Author:
I created a 2nd PR to this effect: #3237

Contributor:
Thanks, looking into it. I'm closing this PR in the meantime.

Indicates whether to use the workgroup configuration for the query execution.
If True, the workgroup configuration will be retrieved and used to determine the S3 output location, encryption, and KMS key.
If False, the S3 output location, encryption, and KMS key will not be set and will instead be determined by the AWS Athena service.
Default is True.

Returns
-------
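To make the parameter documented above concrete, here is a minimal usage sketch. It assumes a build of awswrangler that includes this PR; the database, workgroup, bucket, and KMS key names are placeholders, and the comments describe the intended behavior as stated in the docstring and the review thread, not verified behavior.

import awswrangler as wr

# Default (retreive_workgroup_config=True): the workgroup configuration is
# fetched from the Athena API and used for the S3 output location, encryption,
# and KMS key.
query_id = wr.athena.start_query_execution(
    sql="SELECT 1",
    database="my_database",      # placeholder
    workgroup="my_workgroup",    # placeholder
)

# With the flag disabled, awswrangler presumably skips the workgroup lookup and
# leaves it to Athena to resolve the output location and encryption from the
# workgroup defaults.
query_id = wr.athena.start_query_execution(
    sql="SELECT 1",
    database="my_database",
    workgroup="my_workgroup",
    retreive_workgroup_config=False,
)

# The reviewer's suggested alternative: pass the settings explicitly, which
# start_query_execution already supports today.
query_id = wr.athena.start_query_execution(
    sql="SELECT 1",
    database="my_database",
    s3_output="s3://my-bucket/athena-results/",  # placeholder bucket
    encryption="SSE_KMS",
    kms_key="my-kms-key-id",                     # placeholder key
)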
@@ -149,7 +155,9 @@ def start_query_execution(
query_execution_id = cache_info.query_execution_id
_logger.debug("Valid cache found. Retrieving...")
else:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
wg_config: _WorkGroupConfig = _get_workgroup_config(
session=boto3_session, workgroup=workgroup, retreive_workgroup_config=retreive_workgroup_config
)
query_execution_id = _start_query_execution(
sql=sql,
wg_config=wg_config,
3 changes: 3 additions & 0 deletions awswrangler/athena/_executions.pyi
@@ -24,6 +24,7 @@ def start_query_execution(
athena_query_wait_polling_delay: float = ...,
data_source: str | None = ...,
wait: Literal[False] = ...,
retreive_workgroup_config: bool = ...,
) -> str: ...
@overload
def start_query_execution(
@@ -42,6 +43,7 @@ def start_query_execution(
athena_query_wait_polling_delay: float = ...,
data_source: str | None = ...,
wait: Literal[True],
retreive_workgroup_config: bool = ...,
) -> dict[str, Any]: ...
@overload
def start_query_execution(
@@ -60,6 +62,7 @@ def start_query_execution(
athena_query_wait_polling_delay: float = ...,
data_source: str | None = ...,
wait: bool,
retreive_workgroup_config: bool = ...,
) -> str | dict[str, Any]: ...
def stop_query_execution(query_execution_id: str, boto3_session: boto3.Session | None = ...) -> None: ...
def wait_query(
32 changes: 30 additions & 2 deletions awswrangler/athena/_read.py
@@ -321,6 +321,7 @@
pyarrow_additional_kwargs: dict[str, Any] | None = None,
execution_params: list[str] | None = None,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
retreive_workgroup_config: bool = True,
) -> pd.DataFrame | Iterator[pd.DataFrame]:
ctas_query_info: dict[str, str | _QueryMetadata] = create_ctas_table(
sql=sql,
@@ -339,6 +340,7 @@
boto3_session=boto3_session,
params=execution_params,
paramstyle="qmark",
retreive_workgroup_config=retreive_workgroup_config,
)
fully_qualified_name: str = f'"{ctas_query_info["ctas_database"]}"."{ctas_query_info["ctas_table"]}"'
ctas_query_metadata = cast(_QueryMetadata, ctas_query_info["ctas_query_metadata"])
@@ -379,6 +381,7 @@
pyarrow_additional_kwargs: dict[str, Any] | None = None,
execution_params: list[str] | None = None,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
retreive_workgroup_config: bool = True,
) -> pd.DataFrame | Iterator[pd.DataFrame]:
query_metadata = _unload(
sql=sql,
@@ -395,6 +398,7 @@
data_source=data_source,
athena_query_wait_polling_delay=athena_query_wait_polling_delay,
execution_params=execution_params,
retreive_workgroup_config=retreive_workgroup_config,
)
if file_format == "PARQUET":
return _fetch_parquet_result(
@@ -430,8 +434,11 @@
result_reuse_configuration: dict[str, Any] | None = None,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
client_request_token: str | None = None,
retreive_workgroup_config: bool = True,
) -> pd.DataFrame | Iterator[pd.DataFrame]:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
wg_config: _WorkGroupConfig = _get_workgroup_config(
session=boto3_session, workgroup=workgroup, retreive_workgroup_config=retreive_workgroup_config
)
s3_output = _get_s3_output(s3_output=s3_output, wg_config=wg_config, boto3_session=boto3_session)
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
_logger.debug("Executing sql: %s", sql)
@@ -496,6 +503,7 @@
result_reuse_configuration: dict[str, Any] | None = None,
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
client_request_token: str | None = None,
retreive_workgroup_config: bool = True,
) -> pd.DataFrame | Iterator[pd.DataFrame]:
"""
Execute a query in Athena and returns results as DataFrame, back to `read_sql_query`.
@@ -530,6 +538,7 @@
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
execution_params=execution_params,
dtype_backend=dtype_backend,
retreive_workgroup_config=retreive_workgroup_config,
)
finally:
catalog.delete_table_if_exists(database=ctas_database or database, table=name, boto3_session=boto3_session)
@@ -558,6 +567,7 @@
pyarrow_additional_kwargs=pyarrow_additional_kwargs,
execution_params=execution_params,
dtype_backend=dtype_backend,
retreive_workgroup_config=retreive_workgroup_config,
)
return _resolve_query_without_cache_regular(
sql=sql,
@@ -578,6 +588,7 @@
result_reuse_configuration=result_reuse_configuration,
dtype_backend=dtype_backend,
client_request_token=client_request_token,
retreive_workgroup_config=retreive_workgroup_config,
)


@@ -596,8 +607,11 @@
data_source: str | None,
athena_query_wait_polling_delay: float,
execution_params: list[str] | None,
retreive_workgroup_config: bool = True,
) -> _QueryMetadata:
wg_config: _WorkGroupConfig = _get_workgroup_config(session=boto3_session, workgroup=workgroup)
wg_config: _WorkGroupConfig = _get_workgroup_config(
session=boto3_session, workgroup=workgroup, retreive_workgroup_config=retreive_workgroup_config
)
s3_output: str = _get_s3_output(s3_output=path, wg_config=wg_config, boto3_session=boto3_session)
s3_output = s3_output[:-1] if s3_output[-1] == "/" else s3_output
# Athena does not enforce a Query Result Location for UNLOAD. Thus, the workgroup output location
@@ -767,7 +781,7 @@
@_utils.validate_distributed_kwargs(
unsupported_kwargs=["boto3_session", "s3_additional_kwargs"],
)
def read_sql_query(

Check failure on line 784 in awswrangler/athena/_read.py
GitHub Actions / Check (3.9): Ruff (PLR0913)
awswrangler/athena/_read.py:784:5: PLR0913 Too many arguments in function definition (26 > 25)
sql: str,
database: str,
ctas_approach: bool = True,
@@ -793,6 +807,7 @@
dtype_backend: Literal["numpy_nullable", "pyarrow"] = "numpy_nullable",
s3_additional_kwargs: dict[str, Any] | None = None,
pyarrow_additional_kwargs: dict[str, Any] | None = None,
retreive_workgroup_config: bool = True,
) -> pd.DataFrame | Iterator[pd.DataFrame]:
"""Execute any SQL query on AWS Athena and return the results as a Pandas DataFrame.

@@ -1002,6 +1017,11 @@
Forwarded to `to_pandas` method converting from PyArrow tables to Pandas DataFrame.
Valid values include "split_blocks", "self_destruct", "ignore_metadata".
e.g. pyarrow_additional_kwargs={'split_blocks': True}.
retreive_workgroup_config
Indicates whether to use the workgroup configuration for the query execution.
If True, the workgroup configuration will be retrieved and used to determine the S3 output location, encryption, and KMS key.
If False, the S3 output location, encryption, and KMS key will not be set and will instead be determined by the AWS Athena service.
Default is True.

Returns
-------
@@ -1120,6 +1140,7 @@
result_reuse_configuration=result_reuse_configuration,
dtype_backend=dtype_backend,
client_request_token=client_request_token,
retreive_workgroup_config=retreive_workgroup_config,
)


@@ -1386,6 +1407,7 @@
params: dict[str, Any] | list[str] | None = None,
paramstyle: Literal["qmark", "named"] = "named",
athena_query_wait_polling_delay: float = _QUERY_WAIT_POLLING_DELAY,
retreive_workgroup_config: bool = True,
) -> _QueryMetadata:
"""Write query results from a SELECT statement to the specified data format using UNLOAD.

Expand Down Expand Up @@ -1442,6 +1464,11 @@
- ``qmark``
athena_query_wait_polling_delay
Interval in seconds for how often the function will check if the Athena query has completed.
retreive_workgroup_config
Indicates whether to use the workgroup configuration for the query execution.
If True, the workgroup configuration will be retrieved and used to determine the S3 output location, encryption, and KMS key.
If False, the S3 output location, encryption, and KMS key will not be set and will instead be determined by the AWS Athena service.
Default is True.

Returns
-------
@@ -1473,4 +1500,5 @@
boto3_session=boto3_session,
data_source=data_source,
execution_params=execution_params,
retreive_workgroup_config=retreive_workgroup_config,
)
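The same flag is threaded through read_sql_query and unload above. A brief sketch of how a caller might combine it with ctas_approach=False, again assuming a build that includes this PR and using placeholder table, database, workgroup, and bucket names:

import awswrangler as wr

# With ctas_approach=False and retreive_workgroup_config=False, the intent is
# that no workgroup lookup is performed and Athena applies the workgroup's own
# result-location and encryption defaults.
df = wr.athena.read_sql_query(
    sql="SELECT * FROM my_table LIMIT 10",  # placeholder table
    database="my_database",                 # placeholder database
    ctas_approach=False,
    workgroup="my_workgroup",               # placeholder workgroup
    retreive_workgroup_config=False,
)

# unload() accepts the same flag; the UNLOAD target path is always explicit,
# so the workgroup output location is not required here.
query_metadata = wr.athena.unload(
    sql="SELECT * FROM my_table",
    path="s3://my-bucket/unload/",          # placeholder bucket
    database="my_database",
    file_format="PARQUET",
    retreive_workgroup_config=False,
)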
11 changes: 11 additions & 0 deletions awswrangler/athena/_read.pyi
@@ -78,6 +78,7 @@ def read_sql_query(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> pd.DataFrame: ...
@overload
def read_sql_query(
@@ -105,6 +106,7 @@ def read_sql_query(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> Iterator[pd.DataFrame]: ...
@overload
def read_sql_query(
@@ -132,6 +134,7 @@ def read_sql_query(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
@overload
def read_sql_query(
@@ -159,6 +162,7 @@ def read_sql_query(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> Iterator[pd.DataFrame]: ...
@overload
def read_sql_query(
@@ -186,6 +190,7 @@ def read_sql_query(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
@overload
def read_sql_table(
@@ -210,6 +215,7 @@ def read_sql_table(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> pd.DataFrame: ...
@overload
def read_sql_table(
@@ -234,6 +240,7 @@ def read_sql_table(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> Iterator[pd.DataFrame]: ...
@overload
def read_sql_table(
@@ -258,6 +265,7 @@ def read_sql_table(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
@overload
def read_sql_table(
@@ -282,6 +290,7 @@ def read_sql_table(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> Iterator[pd.DataFrame]: ...
@overload
def read_sql_table(
@@ -306,6 +315,7 @@ def read_sql_table(
dtype_backend: Literal["numpy_nullable", "pyarrow"] = ...,
s3_additional_kwargs: dict[str, Any] | None = ...,
pyarrow_additional_kwargs: dict[str, Any] | None = ...,
retreive_workgroup_config: bool = ...,
) -> pd.DataFrame | Iterator[pd.DataFrame]: ...
def unload(
sql: str,
@@ -323,4 +333,5 @@
params: dict[str, Any] | list[str] | None = ...,
paramstyle: Literal["qmark", "named"] = ...,
athena_query_wait_polling_delay: float = ...,
retreive_workgroup_config: bool = ...,
) -> _QueryMetadata: ...