diff --git a/src/replicate/lib/_files.py b/src/replicate/lib/_files.py index ad49a4c..04a17c3 100644 --- a/src/replicate/lib/_files.py +++ b/src/replicate/lib/_files.py @@ -19,6 +19,30 @@ FileEncodingStrategy = Literal["base64", "url"] +def filter_none_values(obj: Any) -> Any: # noqa: ANN401 + """ + Recursively filter out None values from dictionaries. + + This preserves the legacy behavior where None-valued inputs are removed + before making API requests. + + Args: + obj: The object to filter. + + Returns: + The object with None values removed from all nested dictionaries. + """ + if isinstance(obj, dict): + return { + key: filter_none_values(value) + for key, value in obj.items() # type: ignore[misc] + if value is not None + } + if isinstance(obj, (list, tuple)): + return type(obj)(filter_none_values(item) for item in obj) # type: ignore[arg-type, misc] + return obj + + try: import numpy as np # type: ignore @@ -35,12 +59,15 @@ def encode_json( ) -> Any: # noqa: ANN401 """ Return a JSON-compatible version of the object. + + None values are filtered out from dictionaries to prevent API errors. """ if isinstance(obj, dict): return { key: encode_json(value, client, file_encoding_strategy) for key, value in obj.items() # type: ignore + if value is not None } # type: ignore if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): return [encode_json(value, client, file_encoding_strategy) for value in obj] # type: ignore @@ -70,12 +97,15 @@ async def async_encode_json( ) -> Any: # noqa: ANN401 """ Asynchronously return a JSON-compatible version of the object. + + None values are filtered out from dictionaries to prevent API errors. """ if isinstance(obj, dict): return { key: (await async_encode_json(value, client, file_encoding_strategy)) for key, value in obj.items() # type: ignore + if value is not None } # type: ignore if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)): return [ diff --git a/src/replicate/resources/deployments/predictions.py b/src/replicate/resources/deployments/predictions.py index aa22e7d..0f4bac7 100644 --- a/src/replicate/resources/deployments/predictions.py +++ b/src/replicate/resources/deployments/predictions.py @@ -17,6 +17,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ...lib._files import filter_none_values from ..._base_client import make_request_options from ...types.prediction import Prediction from ...types.deployments import prediction_create_params @@ -176,7 +177,7 @@ def create( f"/deployments/{deployment_owner}/{deployment_name}/predictions", body=maybe_transform( { - "input": input, + "input": filter_none_values(input), "stream": stream, "webhook": webhook, "webhook_events_filter": webhook_events_filter, @@ -342,7 +343,7 @@ async def create( f"/deployments/{deployment_owner}/{deployment_name}/predictions", body=await async_maybe_transform( { - "input": input, + "input": filter_none_values(input), "stream": stream, "webhook": webhook, "webhook_events_filter": webhook_events_filter, diff --git a/src/replicate/resources/trainings.py b/src/replicate/resources/trainings.py index 51ab733..7f9abbd 100644 --- a/src/replicate/resources/trainings.py +++ b/src/replicate/resources/trainings.py @@ -18,6 +18,7 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ..lib._files import filter_none_values from ..pagination import SyncCursorURLPage, AsyncCursorURLPage from .._base_client import AsyncPaginator, make_request_options from ..types.training_get_response import TrainingGetResponse @@ -187,7 +188,7 @@ def create( body=maybe_transform( { "destination": destination, - "input": input, + "input": filter_none_values(input), "webhook": webhook, "webhook_events_filter": webhook_events_filter, }, @@ -573,7 +574,7 @@ async def create( body=await async_maybe_transform( { "destination": destination, - "input": input, + "input": filter_none_values(input), "webhook": webhook, "webhook_events_filter": webhook_events_filter, }, diff --git a/tests/test_filter_none_values.py b/tests/test_filter_none_values.py new file mode 100644 index 0000000..afabfd0 --- /dev/null +++ b/tests/test_filter_none_values.py @@ -0,0 +1,136 @@ +from replicate import Replicate +from replicate.lib._files import encode_json, async_encode_json, filter_none_values + + +def test_filter_none_values_simple_dict(): + """Test that None values are filtered from a simple dictionary.""" + input_dict = {"prompt": "banana", "seed": None, "width": 512} + result = filter_none_values(input_dict) + assert result == {"prompt": "banana", "width": 512} + assert "seed" not in result + + +def test_filter_none_values_nested_dict(): + """Test that None values are filtered from nested dictionaries.""" + input_dict = { + "prompt": "banana", + "config": {"seed": None, "temperature": 0.8, "iterations": None}, + "width": 512, + } + result = filter_none_values(input_dict) + assert result == { + "prompt": "banana", + "config": {"temperature": 0.8}, + "width": 512, + } + assert "seed" not in result["config"] + assert "iterations" not in result["config"] + + +def test_filter_none_values_all_none(): + """Test that a dict with all None values returns an empty dict.""" + input_dict = {"seed": None, "temperature": None} + result = filter_none_values(input_dict) + assert result == {} + + +def test_filter_none_values_empty_dict(): + """Test that an empty dict returns an empty dict.""" + input_dict = {} + result = filter_none_values(input_dict) + assert result == {} + + +def test_filter_none_values_with_list(): + """Test that lists are preserved and None values in dicts within lists are filtered.""" + input_dict = { + "prompts": ["banana", "apple"], + "seeds": [None, 42, None], + "config": {"value": None}, + } + result = filter_none_values(input_dict) + # None values in lists are preserved + assert result == { + "prompts": ["banana", "apple"], + "seeds": [None, 42, None], + "config": {}, + } + + +def test_filter_none_values_with_tuple(): + """Test that tuples are preserved.""" + input_dict = {"coords": (1, None, 3)} + result = filter_none_values(input_dict) + # Tuples are preserved as-is + assert result == {"coords": (1, None, 3)} + + +def test_filter_none_values_non_dict(): + """Test that non-dict values are returned as-is.""" + assert filter_none_values("string") == "string" + assert filter_none_values(42) == 42 + assert filter_none_values(None) is None + assert filter_none_values([1, 2, 3]) == [1, 2, 3] + + +def test_encode_json_filters_none(client: Replicate): + """Test that encode_json filters None values from dicts.""" + input_dict = {"prompt": "banana", "seed": None, "width": 512} + result = encode_json(input_dict, client) + assert result == {"prompt": "banana", "width": 512} + assert "seed" not in result + + +def test_encode_json_nested_none_filtering(client: Replicate): + """Test that encode_json recursively filters None values.""" + input_dict = { + "prompt": "banana", + "config": {"seed": None, "temperature": 0.8}, + "metadata": {"user": "test", "session": None}, + } + result = encode_json(input_dict, client) + assert result == { + "prompt": "banana", + "config": {"temperature": 0.8}, + "metadata": {"user": "test"}, + } + + +async def test_async_encode_json_filters_none(async_client): # type: ignore[no-untyped-def] + """Test that async_encode_json filters None values from dicts.""" + input_dict = {"prompt": "banana", "seed": None, "width": 512} + result = await async_encode_json(input_dict, async_client) # type: ignore[arg-type] + assert result == {"prompt": "banana", "width": 512} + assert "seed" not in result + + +async def test_async_encode_json_nested_none_filtering(async_client): # type: ignore[no-untyped-def] + """Test that async_encode_json recursively filters None values.""" + input_dict = { + "prompt": "banana", + "config": {"seed": None, "temperature": 0.8}, + "metadata": {"user": "test", "session": None}, + } + result = await async_encode_json(input_dict, async_client) # type: ignore[arg-type] + assert result == { + "prompt": "banana", + "config": {"temperature": 0.8}, + "metadata": {"user": "test"}, + } + + +def test_encode_json_preserves_false_and_zero(client: Replicate): + """Test that False and 0 are not filtered out.""" + input_dict = { + "prompt": "banana", + "seed": 0, + "enable_feature": False, + "iterations": None, + } + result = encode_json(input_dict, client) + assert result == { + "prompt": "banana", + "seed": 0, + "enable_feature": False, + } + assert "iterations" not in result