diff --git a/cadence/client.py b/cadence/client.py index a75d7b5..ef140fd 100644 --- a/cadence/client.py +++ b/cadence/client.py @@ -2,7 +2,7 @@ import socket import uuid from datetime import timedelta -from typing import TypedDict, Unpack, Any, cast, Union, Callable +from typing import TypedDict, Unpack, Any, cast, Union from grpc import ChannelCredentials, Compression from google.protobuf.duration_pb2 import Duration @@ -17,11 +17,14 @@ from cadence.api.v1.service_workflow_pb2 import ( StartWorkflowExecutionRequest, StartWorkflowExecutionResponse, + SignalWithStartWorkflowExecutionRequest, + SignalWithStartWorkflowExecutionResponse, ) from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution from cadence.api.v1.tasklist_pb2 import TaskList from cadence.data_converter import DataConverter, DefaultDataConverter from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter +from cadence.workflow import WorkflowDefinition class StartWorkflowOptions(TypedDict, total=False): @@ -132,7 +135,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None: def _build_start_workflow_request( self, - workflow: Union[str, Callable], + workflow: Union[str, WorkflowDefinition], args: tuple[Any, ...], options: StartWorkflowOptions, ) -> StartWorkflowExecutionRequest: @@ -144,8 +147,8 @@ def _build_start_workflow_request( if isinstance(workflow, str): workflow_type_name = workflow else: - # For callable, use function name or __name__ attribute - workflow_type_name = getattr(workflow, "__name__", str(workflow)) + # For WorkflowDefinition, use the name property + workflow_type_name = workflow.name # Encode input arguments input_payload = None @@ -186,7 +189,7 @@ def _build_start_workflow_request( async def start_workflow( self, - workflow: Union[str, Callable], + workflow: Union[str, WorkflowDefinition], *args, **options_kwargs: Unpack[StartWorkflowOptions], ) -> WorkflowExecution: @@ -194,7 +197,7 @@ async def start_workflow( Start a workflow execution asynchronously. Args: - workflow: Workflow function or workflow type name string + workflow: WorkflowDefinition or workflow type name string *args: Arguments to pass to the workflow **options_kwargs: StartWorkflowOptions as keyword arguments @@ -229,6 +232,72 @@ async def start_workflow( except Exception: raise + async def signal_with_start_workflow( + self, + workflow: Union[str, WorkflowDefinition], + signal_name: str, + signal_input: Any = None, + *args, + **options_kwargs: Unpack[StartWorkflowOptions], + ) -> WorkflowExecution: + """ + Signal a workflow execution, starting it if it is not already running. + + Args: + workflow: WorkflowDefinition or workflow type name string + signal_name: Name of the signal + signal_input: Input data for the signal + *args: Arguments to pass to the workflow if it needs to be started + **options_kwargs: StartWorkflowOptions as keyword arguments + + Returns: + WorkflowExecution with workflow_id and run_id + + Raises: + ValueError: If required parameters are missing or invalid + Exception: If the gRPC call fails + """ + # Convert kwargs to StartWorkflowOptions and validate + options = _validate_and_apply_defaults(StartWorkflowOptions(**options_kwargs)) + + # Build the start workflow request + start_request = self._build_start_workflow_request(workflow, args, options) + + # Encode signal input + signal_payload = None + if signal_input is not None: + try: + signal_payload = self.data_converter.to_data([signal_input]) + except Exception as e: + raise ValueError(f"Failed to encode signal input: {e}") + + # Build the SignalWithStartWorkflowExecution request + request = SignalWithStartWorkflowExecutionRequest( + start_request=start_request, + signal_name=signal_name, + ) + + if signal_payload: + request.signal_input.CopyFrom(signal_payload) + + # Execute the gRPC call + try: + response: SignalWithStartWorkflowExecutionResponse = ( + await self.workflow_stub.SignalWithStartWorkflowExecution(request) + ) + + # Emit metrics if available + if self.metrics_emitter: + # TODO: Add metrics similar to Go client + pass + + execution = WorkflowExecution() + execution.workflow_id = start_request.workflow_id + execution.run_id = response.run_id + return execution + except Exception: + raise + def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions: if "target" not in options: diff --git a/cadence/workflow.py b/cadence/workflow.py index 913ebd1..de8791e 100644 --- a/cadence/workflow.py +++ b/cadence/workflow.py @@ -4,20 +4,23 @@ from dataclasses import dataclass from datetime import timedelta from typing import ( + Iterator, Callable, + TypeVar, + TypedDict, + Type, cast, + Any, Optional, Union, - Iterator, - TypedDict, - TypeVar, - Type, + TYPE_CHECKING, Unpack, - Any, ) import inspect -from cadence.client import Client +if TYPE_CHECKING: + from cadence.client import Client + from cadence.data_converter import DataConverter ResultType = TypeVar("ResultType") @@ -178,7 +181,7 @@ class WorkflowContext(ABC): def info(self) -> WorkflowInfo: ... @abstractmethod - def client(self) -> Client: ... + def client(self) -> "Client": ... @abstractmethod def data_converter(self) -> DataConverter: ... diff --git a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py index 3aa5d44..c40a1e6 100644 --- a/tests/cadence/_internal/workflow/test_workflow_engine_integration.py +++ b/tests/cadence/_internal/workflow/test_workflow_engine_integration.py @@ -4,7 +4,7 @@ """ import pytest -from unittest.mock import Mock, AsyncMock, patch +from unittest.mock import Mock, patch from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType from cadence.api.v1.history_pb2 import ( @@ -244,7 +244,7 @@ async def test_extract_workflow_input_deserialization_error( decision_task = self.create_mock_decision_task() # Mock data converter to raise an exception - mock_client.data_converter.from_data = AsyncMock( + mock_client.data_converter.from_data = Mock( side_effect=Exception("Deserialization error") ) diff --git a/tests/cadence/test_client_workflow.py b/tests/cadence/test_client_workflow.py index cdf7a2c..acb1a98 100644 --- a/tests/cadence/test_client_workflow.py +++ b/tests/cadence/test_client_workflow.py @@ -10,6 +10,7 @@ ) from cadence.client import Client, StartWorkflowOptions, _validate_and_apply_defaults from cadence.data_converter import DefaultDataConverter +from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions @pytest.fixture @@ -96,11 +97,17 @@ async def test_build_request_with_string_workflow(self, mock_client): uuid.UUID(request.request_id) # This will raise if not valid UUID @pytest.mark.asyncio - async def test_build_request_with_callable_workflow(self, mock_client): - """Test building request with callable workflow.""" + async def test_build_request_with_workflow_definition(self, mock_client): + """Test building request with WorkflowDefinition.""" + from cadence import workflow - def test_workflow(): - pass + class TestWorkflow: + @workflow.run + async def run(self): + pass + + workflow_opts = WorkflowDefinitionOptions(name="test_workflow") + workflow_definition = WorkflowDefinition.wrap(TestWorkflow, workflow_opts) client = Client(domain="test-domain", target="localhost:7933") @@ -110,7 +117,7 @@ def test_workflow(): task_start_to_close_timeout=timedelta(seconds=30), ) - request = client._build_start_workflow_request(test_workflow, (), options) + request = client._build_start_workflow_request(workflow_definition, (), options) assert request.workflow_type.name == "test_workflow" diff --git a/tests/integration_tests/test_client.py b/tests/integration_tests/test_client.py index 5b4e785..1a0e31c 100644 --- a/tests/integration_tests/test_client.py +++ b/tests/integration_tests/test_client.py @@ -135,3 +135,87 @@ async def test_workflow_stub_start_and_describe(helper: CadenceHelper): assert task_timeout_seconds == task_timeout.total_seconds(), ( f"task_start_to_close_timeout mismatch: expected {task_timeout.total_seconds()}s, got {task_timeout_seconds}s" ) + + +# trying parametrized test for table test +@pytest.mark.parametrize( + "test_case,workflow_id,start_first,expected_same_run", + [ + ( + "new_workflow", + "test-workflow-signal-with-start-123", + False, + False, + ), + ( + "existing_workflow", + "test-workflow-signal-existing-456", + True, + True, + ), + ], +) +@pytest.mark.usefixtures("helper") +async def test_signal_with_start_workflow( + helper: CadenceHelper, + test_case: str, + workflow_id: str, + start_first: bool, + expected_same_run: bool, +): + """Test signal_with_start_workflow method. + + Test cases: + 1. new_workflow: SignalWithStartWorkflow starts a new workflow if it doesn't exist + 2. existing_workflow: SignalWithStartWorkflow signals existing workflow without restart + """ + async with helper.client() as client: + workflow_type = f"test-workflow-signal-{test_case}" + task_list_name = f"test-task-list-signal-{test_case}" + execution_timeout = timedelta(minutes=5) + signal_name = "test-signal" + signal_input = {"data": "test-signal-data"} + + first_run_id = None + if start_first: + first_execution = await client.start_workflow( + workflow_type, + task_list=task_list_name, + execution_start_to_close_timeout=execution_timeout, + workflow_id=workflow_id, + ) + first_run_id = first_execution.run_id + + execution = await client.signal_with_start_workflow( + workflow_type, + signal_name, + signal_input, + "arg1", + "arg2", + task_list=task_list_name, + execution_start_to_close_timeout=execution_timeout, + workflow_id=workflow_id, + ) + + assert execution is not None + assert execution.workflow_id == workflow_id + assert execution.run_id is not None + assert execution.run_id != "" + + if expected_same_run: + assert execution.run_id == first_run_id + + describe_request = DescribeWorkflowExecutionRequest( + domain=DOMAIN_NAME, + workflow_execution=WorkflowExecution( + workflow_id=execution.workflow_id, + run_id=execution.run_id, + ), + ) + + response = await client.workflow_stub.DescribeWorkflowExecution( + describe_request + ) + + assert response.workflow_execution_info.type.name == workflow_type + assert response.workflow_execution_info.task_list == task_list_name