Skip to content

Commit e4393fa

Browse files
Fix overflow and dtype handling in rgblike_to_depthmap (NumPy + PyTorch) (#12546)
* Fix overflow in rgblike_to_depthmap by safe dtype casting (torch & NumPy) * Fix: store original dtype and cast back after safe computation * Apply style fixes --------- Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent b3e9dfc commit e4393fa

File tree

1 file changed

+31
-8
lines changed

1 file changed

+31
-8
lines changed

src/diffusers/image_processor.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1045,16 +1045,39 @@ def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) ->
10451045
def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
    r"""
    Convert an RGB-like depth image to a depth map.

    The 16-bit depth value is reconstructed from two 8-bit channels as
    ``channel_1 * 256 + channel_2`` (high byte, low byte).

    Args:
        image (`Union[np.ndarray, torch.Tensor]`):
            The RGB-like depth image to convert, laid out as ``(H, W, C)`` with at
            least 3 channels.

    Returns:
        `Union[np.ndarray, torch.Tensor]`:
            The corresponding depth map, with dtype ``int32``. The arithmetic is
            performed in ``int32`` so that ``high_byte * 256 + low_byte`` cannot
            overflow a narrow input dtype such as ``uint8``, and the result is
            deliberately NOT cast back to the input dtype: a 16-bit depth value
            does not fit in 8 bits, so casting back would silently truncate and
            reintroduce the overflow the widening was meant to prevent.

    Raises:
        TypeError: If `image` is neither a `torch.Tensor` nor a `np.ndarray`.
    """
    if isinstance(image, torch.Tensor):
        # Widen BEFORE the arithmetic: e.g. uint8 * 256 would wrap around in uint8.
        image_safe = image.to(torch.int32)
        return image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
    elif isinstance(image, np.ndarray):
        # NumPy equivalent: cast to a safe dtype before combining the two bytes.
        image_safe = image.astype(np.int32)
        return image_safe[:, :, 1] * 256 + image_safe[:, :, 2]
    else:
        raise TypeError("Input image must be a torch.Tensor or np.ndarray")
10581081

10591082
def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
10601083
r"""

0 commit comments

Comments
 (0)