11#!/usr/bin/env python
2- # -*- coding: utf-8; -*-
3-
4- # Copyright (c) 2024 Oracle and/or its affiliates.
2+ # Copyright (c) 2024, 2025 Oracle and/or its affiliates.
53# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
64
75import logging
1210import oci
1311from oci import Signer
1412from tqdm .auto import tqdm
13+
1514from ads .common .oci_datascience import OCIDataScienceMixin
1615
1716logger = logging .getLogger (__name__ )
2019DEFAULT_WAIT_TIME = 1200
2120DEFAULT_POLL_INTERVAL = 10
2221WORK_REQUEST_PERCENTAGE = 100
23- # default tqdm progress bar format:
22+ # default tqdm progress bar format:
2423# {l_bar}{bar}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, ' '{rate_fmt}{postfix}]
2524# customize the bar format to remove the {n_fmt}/{total_fmt} from the right side
26- DEFAULT_BAR_FORMAT = ' {l_bar}{bar}| [{elapsed}<{remaining}, ' ' {rate_fmt}{postfix}]'
25+ DEFAULT_BAR_FORMAT = " {l_bar}{bar}| [{elapsed}<{remaining}, " " {rate_fmt}{postfix}]"
2726
2827
2928class DataScienceWorkRequest (OCIDataScienceMixin ):
@@ -32,13 +31,13 @@ class DataScienceWorkRequest(OCIDataScienceMixin):
3231 """
3332
3433 def __init__ (
35- self ,
36- id : str ,
34+ self ,
35+ id : str ,
3736 description : str = "Processing" ,
38- config : dict = None ,
39- signer : Signer = None ,
40- client_kwargs : dict = None ,
41- ** kwargs
37+ config : dict = None ,
38+ signer : Signer = None ,
39+ client_kwargs : dict = None ,
40+ ** kwargs ,
4241 ) -> None :
4342 """Initializes ADSWorkRequest object.
4443
@@ -49,41 +48,43 @@ def __init__(
4948 description: str
5049 Progress bar initial step description (Defaults to `Processing`).
5150 config : dict, optional
52- OCI API key config dictionary to initialize
51+ OCI API key config dictionary to initialize
5352 oci.data_science.DataScienceClient (Defaults to None).
5453 signer : oci.signer.Signer, optional
55- OCI authentication signer to initialize
54+ OCI authentication signer to initialize
5655 oci.data_science.DataScienceClient (Defaults to None).
5756 client_kwargs : dict, optional
58- Additional client keyword arguments to initialize
57+ Additional client keyword arguments to initialize
5958 oci.data_science.DataScienceClient (Defaults to None).
6059 kwargs:
61- Additional keyword arguments to initialize
60+ Additional keyword arguments to initialize
6261 oci.data_science.DataScienceClient.
6362 """
6463 self .id = id
6564 self ._description = description
6665 self ._percentage = 0
6766 self ._status = None
67+ self ._error_message = ""
6868 super ().__init__ (config , signer , client_kwargs , ** kwargs )
69-
7069
7170 def _sync (self ):
7271 """Fetches the latest work request information to ADSWorkRequest object."""
7372 work_request = self .client .get_work_request (self .id ).data
74- work_request_logs = self .client .list_work_request_logs (
75- self .id
76- ).data
73+ work_request_logs = self .client .list_work_request_logs (self .id ).data
7774
78- self ._percentage = work_request .percent_complete
75+ self ._percentage = work_request .percent_complete
7976 self ._status = work_request .status
80- self ._description = work_request_logs [- 1 ].message if work_request_logs else "Processing"
77+ self ._description = (
78+ work_request_logs [- 1 ].message if work_request_logs else "Processing"
79+ )
80+ if work_request .status == "FAILED" :
81+ self ._error_message = self .client .list_work_request_errors (self .id ).data
8182
8283 def watch (
83- self ,
84+ self ,
8485 progress_callback : Callable ,
85- max_wait_time : int = DEFAULT_WAIT_TIME ,
86- poll_interval : int = DEFAULT_POLL_INTERVAL ,
86+ max_wait_time : int = DEFAULT_WAIT_TIME ,
87+ poll_interval : int = DEFAULT_POLL_INTERVAL ,
8788 ):
8889 """Updates the progress bar with realtime message and percentage until the process is completed.
8990
@@ -92,10 +93,10 @@ def watch(
9293 progress_callback: Callable
9394 Progress bar callback function.
9495 It must accept `(percent_change, description)` where `percent_change` is the
95- work request percent complete and `description` is the latest work request log message.
96+ work request percent complete and `description` is the latest work request log message.
9697 max_wait_time: int
9798 Maximum amount of time to wait in seconds (Defaults to 1200).
98- Negative implies infinite wait time.
99+ Negative implies infinite wait time.
99100 poll_interval: int
100101 Poll interval in seconds (Defaults to 10).
101102
@@ -107,7 +108,6 @@ def watch(
107108
108109 start_time = time .time ()
109110 while self ._percentage < 100 :
110-
111111 seconds_since = time .time () - start_time
112112 if max_wait_time > 0 and seconds_since >= max_wait_time :
113113 logger .error (f"Exceeded max wait time of { max_wait_time } seconds." )
@@ -124,12 +124,14 @@ def watch(
124124 percent_change = self ._percentage - previous_percent_complete
125125 previous_percent_complete = self ._percentage
126126 progress_callback (
127- percent_change = percent_change ,
128- description = self ._description
127+ percent_change = percent_change , description = self ._description
129128 )
130129
131130 if self ._status in WORK_REQUEST_STOP_STATE :
132- if self ._status != oci .work_requests .models .WorkRequest .STATUS_SUCCEEDED :
131+ if (
132+ self ._status
133+ != oci .work_requests .models .WorkRequest .STATUS_SUCCEEDED
134+ ):
133135 if self ._description :
134136 raise Exception (self ._description )
135137 else :
@@ -145,12 +147,12 @@ def watch(
145147
146148 def wait_work_request (
147149 self ,
148- progress_bar_description : str = "Processing" ,
149- max_wait_time : int = DEFAULT_WAIT_TIME ,
150- poll_interval : int = DEFAULT_POLL_INTERVAL
150+ progress_bar_description : str = "Processing" ,
151+ max_wait_time : int = DEFAULT_WAIT_TIME ,
152+ poll_interval : int = DEFAULT_POLL_INTERVAL ,
151153 ):
152154 """Waits for the work request progress bar to be completed.
153-
155+
154156 Parameters
155157 ----------
156158 progress_bar_description: str
@@ -160,7 +162,7 @@ def wait_work_request(
160162 Negative implies infinite wait time.
161163 poll_interval: int
162164 Poll interval in seconds (Defaults to 10).
163-
165+
164166 Returns
165167 -------
166168 None
@@ -172,7 +174,7 @@ def wait_work_request(
172174 mininterval = 0 ,
173175 file = sys .stdout ,
174176 desc = progress_bar_description ,
175- bar_format = DEFAULT_BAR_FORMAT
177+ bar_format = DEFAULT_BAR_FORMAT ,
176178 ) as pbar :
177179
178180 def progress_callback (percent_change , description ):
@@ -184,6 +186,5 @@ def progress_callback(percent_change, description):
184186 self .watch (
185187 progress_callback = progress_callback ,
186188 max_wait_time = max_wait_time ,
187- poll_interval = poll_interval
189+ poll_interval = poll_interval ,
188190 )
189-
0 commit comments