Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 30 additions & 0 deletions src/replicate/lib/_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,30 @@
FileEncodingStrategy = Literal["base64", "url"]


def filter_none_values(obj: Any) -> Any: # noqa: ANN401
"""
Recursively filter out None values from dictionaries.

This preserves the legacy behavior where None-valued inputs are removed
before making API requests.

Args:
obj: The object to filter.

Returns:
The object with None values removed from all nested dictionaries.
"""
if isinstance(obj, dict):
return {
key: filter_none_values(value)
for key, value in obj.items() # type: ignore[misc]
if value is not None
}
if isinstance(obj, (list, tuple)):
return type(obj)(filter_none_values(item) for item in obj) # type: ignore[arg-type, misc]
return obj


try:
import numpy as np # type: ignore

Expand All @@ -35,12 +59,15 @@ def encode_json(
) -> Any: # noqa: ANN401
"""
Return a JSON-compatible version of the object.

None values are filtered out from dictionaries to prevent API errors.
"""

if isinstance(obj, dict):
return {
key: encode_json(value, client, file_encoding_strategy)
for key, value in obj.items() # type: ignore
if value is not None
} # type: ignore
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
return [encode_json(value, client, file_encoding_strategy) for value in obj] # type: ignore
Expand Down Expand Up @@ -70,12 +97,15 @@ async def async_encode_json(
) -> Any: # noqa: ANN401
"""
Asynchronously return a JSON-compatible version of the object.

None values are filtered out from dictionaries to prevent API errors.
"""

if isinstance(obj, dict):
return {
key: (await async_encode_json(value, client, file_encoding_strategy))
for key, value in obj.items() # type: ignore
if value is not None
} # type: ignore
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
return [
Expand Down
5 changes: 3 additions & 2 deletions src/replicate/resources/deployments/predictions.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from ...lib._files import filter_none_values
from ..._base_client import make_request_options
from ...types.prediction import Prediction
from ...types.deployments import prediction_create_params
Expand Down Expand Up @@ -176,7 +177,7 @@ def create(
f"/deployments/{deployment_owner}/{deployment_name}/predictions",
body=maybe_transform(
{
"input": input,
"input": filter_none_values(input),
"stream": stream,
"webhook": webhook,
"webhook_events_filter": webhook_events_filter,
Expand Down Expand Up @@ -342,7 +343,7 @@ async def create(
f"/deployments/{deployment_owner}/{deployment_name}/predictions",
body=await async_maybe_transform(
{
"input": input,
"input": filter_none_values(input),
"stream": stream,
"webhook": webhook,
"webhook_events_filter": webhook_events_filter,
Expand Down
5 changes: 3 additions & 2 deletions src/replicate/resources/trainings.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
async_to_raw_response_wrapper,
async_to_streamed_response_wrapper,
)
from ..lib._files import filter_none_values
from ..pagination import SyncCursorURLPage, AsyncCursorURLPage
from .._base_client import AsyncPaginator, make_request_options
from ..types.training_get_response import TrainingGetResponse
Expand Down Expand Up @@ -187,7 +188,7 @@ def create(
body=maybe_transform(
{
"destination": destination,
"input": input,
"input": filter_none_values(input),
"webhook": webhook,
"webhook_events_filter": webhook_events_filter,
},
Expand Down Expand Up @@ -573,7 +574,7 @@ async def create(
body=await async_maybe_transform(
{
"destination": destination,
"input": input,
"input": filter_none_values(input),
"webhook": webhook,
"webhook_events_filter": webhook_events_filter,
},
Expand Down
136 changes: 136 additions & 0 deletions tests/test_filter_none_values.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
from replicate import Replicate
from replicate.lib._files import encode_json, async_encode_json, filter_none_values


def test_filter_none_values_simple_dict():
"""Test that None values are filtered from a simple dictionary."""
input_dict = {"prompt": "banana", "seed": None, "width": 512}
result = filter_none_values(input_dict)
assert result == {"prompt": "banana", "width": 512}
assert "seed" not in result


def test_filter_none_values_nested_dict():
"""Test that None values are filtered from nested dictionaries."""
input_dict = {
"prompt": "banana",
"config": {"seed": None, "temperature": 0.8, "iterations": None},
"width": 512,
}
result = filter_none_values(input_dict)
assert result == {
"prompt": "banana",
"config": {"temperature": 0.8},
"width": 512,
}
assert "seed" not in result["config"]
assert "iterations" not in result["config"]


def test_filter_none_values_all_none():
"""Test that a dict with all None values returns an empty dict."""
input_dict = {"seed": None, "temperature": None}
result = filter_none_values(input_dict)
assert result == {}


def test_filter_none_values_empty_dict():
"""Test that an empty dict returns an empty dict."""
input_dict = {}
result = filter_none_values(input_dict)
assert result == {}


def test_filter_none_values_with_list():
"""Test that lists are preserved and None values in dicts within lists are filtered."""
input_dict = {
"prompts": ["banana", "apple"],
"seeds": [None, 42, None],
"config": {"value": None},
}
result = filter_none_values(input_dict)
# None values in lists are preserved
assert result == {
"prompts": ["banana", "apple"],
"seeds": [None, 42, None],
"config": {},
}


def test_filter_none_values_with_tuple():
"""Test that tuples are preserved."""
input_dict = {"coords": (1, None, 3)}
result = filter_none_values(input_dict)
# Tuples are preserved as-is
assert result == {"coords": (1, None, 3)}


def test_filter_none_values_non_dict():
"""Test that non-dict values are returned as-is."""
assert filter_none_values("string") == "string"
assert filter_none_values(42) == 42
assert filter_none_values(None) is None
assert filter_none_values([1, 2, 3]) == [1, 2, 3]


def test_encode_json_filters_none(client: Replicate):
"""Test that encode_json filters None values from dicts."""
input_dict = {"prompt": "banana", "seed": None, "width": 512}
result = encode_json(input_dict, client)
assert result == {"prompt": "banana", "width": 512}
assert "seed" not in result


def test_encode_json_nested_none_filtering(client: Replicate):
"""Test that encode_json recursively filters None values."""
input_dict = {
"prompt": "banana",
"config": {"seed": None, "temperature": 0.8},
"metadata": {"user": "test", "session": None},
}
result = encode_json(input_dict, client)
assert result == {
"prompt": "banana",
"config": {"temperature": 0.8},
"metadata": {"user": "test"},
}


async def test_async_encode_json_filters_none(async_client): # type: ignore[no-untyped-def]
"""Test that async_encode_json filters None values from dicts."""
input_dict = {"prompt": "banana", "seed": None, "width": 512}
result = await async_encode_json(input_dict, async_client) # type: ignore[arg-type]
assert result == {"prompt": "banana", "width": 512}
assert "seed" not in result


async def test_async_encode_json_nested_none_filtering(async_client): # type: ignore[no-untyped-def]
"""Test that async_encode_json recursively filters None values."""
input_dict = {
"prompt": "banana",
"config": {"seed": None, "temperature": 0.8},
"metadata": {"user": "test", "session": None},
}
result = await async_encode_json(input_dict, async_client) # type: ignore[arg-type]
assert result == {
"prompt": "banana",
"config": {"temperature": 0.8},
"metadata": {"user": "test"},
}


def test_encode_json_preserves_false_and_zero(client: Replicate):
"""Test that False and 0 are not filtered out."""
input_dict = {
"prompt": "banana",
"seed": 0,
"enable_feature": False,
"iterations": None,
}
result = encode_json(input_dict, client)
assert result == {
"prompt": "banana",
"seed": 0,
"enable_feature": False,
}
assert "iterations" not in result