Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
81 changes: 75 additions & 6 deletions cadence/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import socket
import uuid
from datetime import timedelta
from typing import TypedDict, Unpack, Any, cast, Union, Callable
from typing import TypedDict, Unpack, Any, cast, Union

from grpc import ChannelCredentials, Compression
from google.protobuf.duration_pb2 import Duration
Expand All @@ -17,11 +17,14 @@
from cadence.api.v1.service_workflow_pb2 import (
StartWorkflowExecutionRequest,
StartWorkflowExecutionResponse,
SignalWithStartWorkflowExecutionRequest,
SignalWithStartWorkflowExecutionResponse,
)
from cadence.api.v1.common_pb2 import WorkflowType, WorkflowExecution
from cadence.api.v1.tasklist_pb2 import TaskList
from cadence.data_converter import DataConverter, DefaultDataConverter
from cadence.metrics import MetricsEmitter, NoOpMetricsEmitter
from cadence.workflow import WorkflowDefinition


class StartWorkflowOptions(TypedDict, total=False):
Expand Down Expand Up @@ -132,7 +135,7 @@ async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:

def _build_start_workflow_request(
self,
workflow: Union[str, Callable],
workflow: Union[str, WorkflowDefinition],
args: tuple[Any, ...],
options: StartWorkflowOptions,
) -> StartWorkflowExecutionRequest:
Expand All @@ -144,8 +147,8 @@ def _build_start_workflow_request(
if isinstance(workflow, str):
workflow_type_name = workflow
else:
# For callable, use function name or __name__ attribute
workflow_type_name = getattr(workflow, "__name__", str(workflow))
# For WorkflowDefinition, use the name property
workflow_type_name = workflow.name

# Encode input arguments
input_payload = None
Expand Down Expand Up @@ -186,15 +189,15 @@ def _build_start_workflow_request(

async def start_workflow(
self,
workflow: Union[str, Callable],
workflow: Union[str, WorkflowDefinition],
*args,
**options_kwargs: Unpack[StartWorkflowOptions],
) -> WorkflowExecution:
"""
Start a workflow execution asynchronously.

Args:
workflow: Workflow function or workflow type name string
workflow: WorkflowDefinition or workflow type name string
*args: Arguments to pass to the workflow
**options_kwargs: StartWorkflowOptions as keyword arguments

Expand Down Expand Up @@ -229,6 +232,72 @@ async def start_workflow(
except Exception:
raise

async def signal_with_start_workflow(
self,
workflow: Union[str, WorkflowDefinition],
signal_name: str,
signal_input: Any = None,
*args,
**options_kwargs: Unpack[StartWorkflowOptions],
) -> WorkflowExecution:
"""
Signal a workflow execution, starting it if it is not already running.

Args:
workflow: WorkflowDefinition or workflow type name string
signal_name: Name of the signal
signal_input: Input data for the signal
*args: Arguments to pass to the workflow if it needs to be started
**options_kwargs: StartWorkflowOptions as keyword arguments

Returns:
WorkflowExecution with workflow_id and run_id

Raises:
ValueError: If required parameters are missing or invalid
Exception: If the gRPC call fails
"""
# Convert kwargs to StartWorkflowOptions and validate
options = _validate_and_apply_defaults(StartWorkflowOptions(**options_kwargs))

# Build the start workflow request
start_request = self._build_start_workflow_request(workflow, args, options)

# Encode signal input
signal_payload = None
if signal_input is not None:
try:
signal_payload = self.data_converter.to_data([signal_input])
except Exception as e:
raise ValueError(f"Failed to encode signal input: {e}")

# Build the SignalWithStartWorkflowExecution request
request = SignalWithStartWorkflowExecutionRequest(
start_request=start_request,
signal_name=signal_name,
)

if signal_payload:
request.signal_input.CopyFrom(signal_payload)

# Execute the gRPC call
try:
response: SignalWithStartWorkflowExecutionResponse = (
await self.workflow_stub.SignalWithStartWorkflowExecution(request)
)

# Emit metrics if available
if self.metrics_emitter:
# TODO: Add metrics similar to Go client
pass

execution = WorkflowExecution()
execution.workflow_id = start_request.workflow_id
execution.run_id = response.run_id
return execution
except Exception:
raise


def _validate_and_copy_defaults(options: ClientOptions) -> ClientOptions:
if "target" not in options:
Expand Down
17 changes: 10 additions & 7 deletions cadence/workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@
from dataclasses import dataclass
from datetime import timedelta
from typing import (
Iterator,
Callable,
TypeVar,
TypedDict,
Type,
cast,
Any,
Optional,
Union,
Iterator,
TypedDict,
TypeVar,
Type,
TYPE_CHECKING,
Unpack,
Any,
)
import inspect

from cadence.client import Client
if TYPE_CHECKING:
from cadence.client import Client

from cadence.data_converter import DataConverter

ResultType = TypeVar("ResultType")
Expand Down Expand Up @@ -178,7 +181,7 @@ class WorkflowContext(ABC):
def info(self) -> WorkflowInfo: ...

@abstractmethod
def client(self) -> Client: ...
def client(self) -> "Client": ...

@abstractmethod
def data_converter(self) -> DataConverter: ...
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"""

import pytest
from unittest.mock import Mock, AsyncMock, patch
from unittest.mock import Mock, patch
from cadence.api.v1.service_worker_pb2 import PollForDecisionTaskResponse
from cadence.api.v1.common_pb2 import Payload, WorkflowExecution, WorkflowType
from cadence.api.v1.history_pb2 import (
Expand Down Expand Up @@ -244,7 +244,7 @@ async def test_extract_workflow_input_deserialization_error(
decision_task = self.create_mock_decision_task()

# Mock data converter to raise an exception
mock_client.data_converter.from_data = AsyncMock(
mock_client.data_converter.from_data = Mock(
side_effect=Exception("Deserialization error")
)

Expand Down
17 changes: 12 additions & 5 deletions tests/cadence/test_client_workflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
)
from cadence.client import Client, StartWorkflowOptions, _validate_and_apply_defaults
from cadence.data_converter import DefaultDataConverter
from cadence.workflow import WorkflowDefinition, WorkflowDefinitionOptions


@pytest.fixture
Expand Down Expand Up @@ -96,11 +97,17 @@ async def test_build_request_with_string_workflow(self, mock_client):
uuid.UUID(request.request_id) # This will raise if not valid UUID

@pytest.mark.asyncio
async def test_build_request_with_callable_workflow(self, mock_client):
"""Test building request with callable workflow."""
async def test_build_request_with_workflow_definition(self, mock_client):
"""Test building request with WorkflowDefinition."""
from cadence import workflow

def test_workflow():
pass
class TestWorkflow:
@workflow.run
async def run(self):
pass

workflow_opts = WorkflowDefinitionOptions(name="test_workflow")
workflow_definition = WorkflowDefinition.wrap(TestWorkflow, workflow_opts)

client = Client(domain="test-domain", target="localhost:7933")

Expand All @@ -110,7 +117,7 @@ def test_workflow():
task_start_to_close_timeout=timedelta(seconds=30),
)

request = client._build_start_workflow_request(test_workflow, (), options)
request = client._build_start_workflow_request(workflow_definition, (), options)

assert request.workflow_type.name == "test_workflow"

Expand Down
84 changes: 84 additions & 0 deletions tests/integration_tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,3 +135,87 @@ async def test_workflow_stub_start_and_describe(helper: CadenceHelper):
assert task_timeout_seconds == task_timeout.total_seconds(), (
f"task_start_to_close_timeout mismatch: expected {task_timeout.total_seconds()}s, got {task_timeout_seconds}s"
)


# trying parametrized test for table test
@pytest.mark.parametrize(
"test_case,workflow_id,start_first,expected_same_run",
[
(
"new_workflow",
"test-workflow-signal-with-start-123",
False,
False,
),
(
"existing_workflow",
"test-workflow-signal-existing-456",
True,
True,
),
],
)
@pytest.mark.usefixtures("helper")
async def test_signal_with_start_workflow(
helper: CadenceHelper,
test_case: str,
workflow_id: str,
start_first: bool,
expected_same_run: bool,
):
"""Test signal_with_start_workflow method.

Test cases:
1. new_workflow: SignalWithStartWorkflow starts a new workflow if it doesn't exist
2. existing_workflow: SignalWithStartWorkflow signals existing workflow without restart
"""
async with helper.client() as client:
workflow_type = f"test-workflow-signal-{test_case}"
task_list_name = f"test-task-list-signal-{test_case}"
execution_timeout = timedelta(minutes=5)
signal_name = "test-signal"
signal_input = {"data": "test-signal-data"}

first_run_id = None
if start_first:
first_execution = await client.start_workflow(
workflow_type,
task_list=task_list_name,
execution_start_to_close_timeout=execution_timeout,
workflow_id=workflow_id,
)
first_run_id = first_execution.run_id

execution = await client.signal_with_start_workflow(
workflow_type,
signal_name,
signal_input,
"arg1",
"arg2",
task_list=task_list_name,
execution_start_to_close_timeout=execution_timeout,
workflow_id=workflow_id,
)

assert execution is not None
assert execution.workflow_id == workflow_id
assert execution.run_id is not None
assert execution.run_id != ""

if expected_same_run:
assert execution.run_id == first_run_id

describe_request = DescribeWorkflowExecutionRequest(
domain=DOMAIN_NAME,
workflow_execution=WorkflowExecution(
workflow_id=execution.workflow_id,
run_id=execution.run_id,
),
)

response = await client.workflow_stub.DescribeWorkflowExecution(
describe_request
)

assert response.workflow_execution_info.type.name == workflow_type
assert response.workflow_execution_info.task_list == task_list_name