Skip to content

Commit ec8bf54

Browse files
committed
feat: filter None values from inputs before API requests
This PR filters out None-valued inputs from all prediction and training creation methods before making API requests. This preserves legacy behavior and prevents potential API errors when users pass None values. Changes: - Add filter_none_values() utility function to recursively remove None values from dictionaries - Update encode_json() and async_encode_json() to filter None values when processing dicts - Apply filtering to deployments.predictions.create() and trainings.create() methods - Add comprehensive test suite for None filtering functionality Fixes https://linear.app/replicate/issue/DP-737
1 parent f778511 commit ec8bf54

File tree

4 files changed

+168
-4
lines changed

4 files changed

+168
-4
lines changed

src/replicate/lib/_files.py

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,26 @@
1919
FileEncodingStrategy = Literal["base64", "url"]
2020

2121

22+
def filter_none_values(obj: Any) -> Any: # noqa: ANN401
23+
"""
24+
Recursively filter out None values from dictionaries.
25+
26+
This preserves the legacy behavior where None-valued inputs are removed
27+
before making API requests.
28+
29+
Args:
30+
obj: The object to filter.
31+
32+
Returns:
33+
The object with None values removed from all nested dictionaries.
34+
"""
35+
if isinstance(obj, dict):
36+
return {key: filter_none_values(value) for key, value in obj.items() if value is not None}
37+
if isinstance(obj, (list, tuple)):
38+
return type(obj)(filter_none_values(item) for item in obj)
39+
return obj
40+
41+
2242
try:
2343
import numpy as np # type: ignore
2444

@@ -35,12 +55,15 @@ def encode_json(
3555
) -> Any: # noqa: ANN401
3656
"""
3757
Return a JSON-compatible version of the object.
58+
59+
None values are filtered out from dictionaries to prevent API errors.
3860
"""
3961

4062
if isinstance(obj, dict):
4163
return {
4264
key: encode_json(value, client, file_encoding_strategy)
4365
for key, value in obj.items() # type: ignore
66+
if value is not None
4467
} # type: ignore
4568
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
4669
return [encode_json(value, client, file_encoding_strategy) for value in obj] # type: ignore
@@ -70,12 +93,15 @@ async def async_encode_json(
7093
) -> Any: # noqa: ANN401
7194
"""
7295
Asynchronously return a JSON-compatible version of the object.
96+
97+
None values are filtered out from dictionaries to prevent API errors.
7398
"""
7499

75100
if isinstance(obj, dict):
76101
return {
77102
key: (await async_encode_json(value, client, file_encoding_strategy))
78103
for key, value in obj.items() # type: ignore
104+
if value is not None
79105
} # type: ignore
80106
if isinstance(obj, (list, set, frozenset, GeneratorType, tuple)):
81107
return [

src/replicate/resources/deployments/predictions.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
async_to_raw_response_wrapper,
1818
async_to_streamed_response_wrapper,
1919
)
20+
from ...lib._files import filter_none_values
2021
from ..._base_client import make_request_options
2122
from ...types.prediction import Prediction
2223
from ...types.deployments import prediction_create_params
@@ -176,7 +177,7 @@ def create(
176177
f"/deployments/{deployment_owner}/{deployment_name}/predictions",
177178
body=maybe_transform(
178179
{
179-
"input": input,
180+
"input": filter_none_values(input),
180181
"stream": stream,
181182
"webhook": webhook,
182183
"webhook_events_filter": webhook_events_filter,
@@ -342,7 +343,7 @@ async def create(
342343
f"/deployments/{deployment_owner}/{deployment_name}/predictions",
343344
body=await async_maybe_transform(
344345
{
345-
"input": input,
346+
"input": filter_none_values(input),
346347
"stream": stream,
347348
"webhook": webhook,
348349
"webhook_events_filter": webhook_events_filter,

src/replicate/resources/trainings.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
async_to_raw_response_wrapper,
1919
async_to_streamed_response_wrapper,
2020
)
21+
from ..lib._files import filter_none_values
2122
from ..pagination import SyncCursorURLPage, AsyncCursorURLPage
2223
from .._base_client import AsyncPaginator, make_request_options
2324
from ..types.training_get_response import TrainingGetResponse
@@ -187,7 +188,7 @@ def create(
187188
body=maybe_transform(
188189
{
189190
"destination": destination,
190-
"input": input,
191+
"input": filter_none_values(input),
191192
"webhook": webhook,
192193
"webhook_events_filter": webhook_events_filter,
193194
},
@@ -573,7 +574,7 @@ async def create(
573574
body=await async_maybe_transform(
574575
{
575576
"destination": destination,
576-
"input": input,
577+
"input": filter_none_values(input),
577578
"webhook": webhook,
578579
"webhook_events_filter": webhook_events_filter,
579580
},

tests/test_filter_none_values.py

Lines changed: 136 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,136 @@
1+
from replicate import Replicate
2+
from replicate.lib._files import encode_json, async_encode_json, filter_none_values
3+
4+
5+
def test_filter_none_values_simple_dict():
6+
"""Test that None values are filtered from a simple dictionary."""
7+
input_dict = {"prompt": "banana", "seed": None, "width": 512}
8+
result = filter_none_values(input_dict)
9+
assert result == {"prompt": "banana", "width": 512}
10+
assert "seed" not in result
11+
12+
13+
def test_filter_none_values_nested_dict():
14+
"""Test that None values are filtered from nested dictionaries."""
15+
input_dict = {
16+
"prompt": "banana",
17+
"config": {"seed": None, "temperature": 0.8, "iterations": None},
18+
"width": 512,
19+
}
20+
result = filter_none_values(input_dict)
21+
assert result == {
22+
"prompt": "banana",
23+
"config": {"temperature": 0.8},
24+
"width": 512,
25+
}
26+
assert "seed" not in result["config"]
27+
assert "iterations" not in result["config"]
28+
29+
30+
def test_filter_none_values_all_none():
31+
"""Test that a dict with all None values returns an empty dict."""
32+
input_dict = {"seed": None, "temperature": None}
33+
result = filter_none_values(input_dict)
34+
assert result == {}
35+
36+
37+
def test_filter_none_values_empty_dict():
38+
"""Test that an empty dict returns an empty dict."""
39+
input_dict = {}
40+
result = filter_none_values(input_dict)
41+
assert result == {}
42+
43+
44+
def test_filter_none_values_with_list():
45+
"""Test that lists are preserved and None values in dicts within lists are filtered."""
46+
input_dict = {
47+
"prompts": ["banana", "apple"],
48+
"seeds": [None, 42, None],
49+
"config": {"value": None},
50+
}
51+
result = filter_none_values(input_dict)
52+
# None values in lists are preserved
53+
assert result == {
54+
"prompts": ["banana", "apple"],
55+
"seeds": [None, 42, None],
56+
"config": {},
57+
}
58+
59+
60+
def test_filter_none_values_with_tuple():
61+
"""Test that tuples are preserved."""
62+
input_dict = {"coords": (1, None, 3)}
63+
result = filter_none_values(input_dict)
64+
# Tuples are preserved as-is
65+
assert result == {"coords": (1, None, 3)}
66+
67+
68+
def test_filter_none_values_non_dict():
69+
"""Test that non-dict values are returned as-is."""
70+
assert filter_none_values("string") == "string"
71+
assert filter_none_values(42) == 42
72+
assert filter_none_values(None) is None
73+
assert filter_none_values([1, 2, 3]) == [1, 2, 3]
74+
75+
76+
def test_encode_json_filters_none(client: Replicate):
77+
"""Test that encode_json filters None values from dicts."""
78+
input_dict = {"prompt": "banana", "seed": None, "width": 512}
79+
result = encode_json(input_dict, client)
80+
assert result == {"prompt": "banana", "width": 512}
81+
assert "seed" not in result
82+
83+
84+
def test_encode_json_nested_none_filtering(client: Replicate):
85+
"""Test that encode_json recursively filters None values."""
86+
input_dict = {
87+
"prompt": "banana",
88+
"config": {"seed": None, "temperature": 0.8},
89+
"metadata": {"user": "test", "session": None},
90+
}
91+
result = encode_json(input_dict, client)
92+
assert result == {
93+
"prompt": "banana",
94+
"config": {"temperature": 0.8},
95+
"metadata": {"user": "test"},
96+
}
97+
98+
99+
async def test_async_encode_json_filters_none(async_client):
100+
"""Test that async_encode_json filters None values from dicts."""
101+
input_dict = {"prompt": "banana", "seed": None, "width": 512}
102+
result = await async_encode_json(input_dict, async_client)
103+
assert result == {"prompt": "banana", "width": 512}
104+
assert "seed" not in result
105+
106+
107+
async def test_async_encode_json_nested_none_filtering(async_client):
108+
"""Test that async_encode_json recursively filters None values."""
109+
input_dict = {
110+
"prompt": "banana",
111+
"config": {"seed": None, "temperature": 0.8},
112+
"metadata": {"user": "test", "session": None},
113+
}
114+
result = await async_encode_json(input_dict, async_client)
115+
assert result == {
116+
"prompt": "banana",
117+
"config": {"temperature": 0.8},
118+
"metadata": {"user": "test"},
119+
}
120+
121+
122+
def test_encode_json_preserves_false_and_zero(client: Replicate):
123+
"""Test that False and 0 are not filtered out."""
124+
input_dict = {
125+
"prompt": "banana",
126+
"seed": 0,
127+
"enable_feature": False,
128+
"iterations": None,
129+
}
130+
result = encode_json(input_dict, client)
131+
assert result == {
132+
"prompt": "banana",
133+
"seed": 0,
134+
"enable_feature": False,
135+
}
136+
assert "iterations" not in result

0 commit comments

Comments
 (0)