diff --git a/containers/ei-models-runner/Dockerfile b/containers/ei-models-runner/Dockerfile
index 1d134b12..3efdc747 100644
--- a/containers/ei-models-runner/Dockerfile
+++ b/containers/ei-models-runner/Dockerfile
@@ -2,7 +2,7 @@
 #
 # SPDX-License-Identifier: MPL-2.0
 
-FROM public.ecr.aws/z9b3d4t5/inference-container-qc-adreno-702:4d7979284677b6bdb557abe8948fa1395dc89a63
+FROM public.ecr.aws/z9b3d4t5/inference-container-qc-adreno-702:39bcebb78de783cb602e1b361b71d6dafbc959b4
 
 # Create the user and group needed to run the container as non-root
 RUN set -ex; \
diff --git a/src/arduino/app_bricks/camera_code_detection/README.md b/src/arduino/app_bricks/camera_code_detection/README.md
index 786da81d..0b1b10c5 100644
--- a/src/arduino/app_bricks/camera_code_detection/README.md
+++ b/src/arduino/app_bricks/camera_code_detection/README.md
@@ -6,8 +6,8 @@ This Brick enables real-time barcode and QR code scanning from a camera video st
 
 The Camera Code Detection Brick allows you to:
 
-- Capture frames from a USB camera.
-- Configure camera settings (resolution and frame rate).
+- Capture frames from a camera (see the Camera peripheral for the supported camera types).
+- Configure camera settings (resolution and frame rate).
 - Define the type of code to detect: barcodes and/or QR codes.
 - Process detections with customizable callbacks.
 
@@ -22,7 +22,7 @@ The Camera Code Detection Brick allows you to:
 
 ## Prerequisites
 
-To use this Brick you should have a USB camera connected to your board.
+To use this Brick, you can either connect a camera directly to your board or use a network-connected camera.
 
 **Tip**: Use a USB-C® Hub with USB-A connectors to support commercial web cameras.
 
@@ -37,9 +37,25 @@
 def render_frame(frame):
     ...
 
 def handle_detected_code(frame, detection):
     ...
 
-# Select the camera you want to use, its resolution and the max fps
-detection = CameraCodeDetection(camera=0, resolution=(640, 360), fps=10)
+detection = CameraCodeDetection()
 detection.on_frame(render_frame)
 detection.on_detection(handle_detected_code)
-detection.start()
+
+App.run()
 ```
+
+You can also select a specific camera to use:
+
+```python
+from arduino.app_bricks.camera_code_detection import CameraCodeDetection
+from arduino.app_peripherals.camera import Camera
+from arduino.app_utils.app import App
+
+def handle_detected_code(frame, detection):
+    ...
+
+# Select the camera you want to use, its resolution and the max fps
+camera = Camera("rtsp://...", resolution=(640, 360), fps=10)
+detection = CameraCodeDetection(camera)
+detection.on_detection(handle_detected_code)
+
+App.run()
+```
\ No newline at end of file
diff --git a/src/arduino/app_bricks/camera_code_detection/__init__.py b/src/arduino/app_bricks/camera_code_detection/__init__.py
index e2c166fc..084984f1 100644
--- a/src/arduino/app_bricks/camera_code_detection/__init__.py
+++ b/src/arduino/app_bricks/camera_code_detection/__init__.py
@@ -2,7 +2,6 @@
 #
 # SPDX-License-Identifier: MPL-2.0
 
-from .detection import Detection, CameraCodeDetection
-from .utils import draw_bounding_boxes, draw_bounding_box
+from .detection import CameraCodeDetection, Detection
 
-__all__ = ["CameraCodeDetection", "Detection", "draw_bounding_boxes", "draw_bounding_box"]
+__all__ = ["CameraCodeDetection", "Detection"]
diff --git a/src/arduino/app_bricks/camera_code_detection/detection.py b/src/arduino/app_bricks/camera_code_detection/detection.py
index bb020364..964b5870 100644
--- a/src/arduino/app_bricks/camera_code_detection/detection.py
+++ b/src/arduino/app_bricks/camera_code_detection/detection.py
@@ -6,12 +6,12 @@
 import threading
 from typing import Callable
 
-import cv2
 from pyzbar.pyzbar import decode, ZBarSymbol, PyZbarError
 import numpy as np
-from PIL.Image import Image
+from PIL.Image import Image, fromarray
 
-from arduino.app_peripherals.usb_camera import USBCamera
+from arduino.app_peripherals.camera import Camera
+from arduino.app_utils.image import greyscale
 from arduino.app_utils import brick, Logger
 
 logger = Logger("CameraCodeDetection")
@@ -44,7 +44,7 @@ class CameraCodeDetection:
     """Scans a camera video feed for QR codes and/or barcodes.
 
     Args:
-        camera (USBCamera): The USB camera instance. If None, a default camera will be initialized.
+        camera (Camera): The camera instance to use for capturing video. If None, a default camera will be initialized.
         detect_qr (bool): Whether to detect QR codes. Defaults to True.
         detect_barcode (bool): Whether to detect barcodes. Defaults to True.
@@ -55,7 +55,7 @@ class CameraCodeDetection:
 
     def __init__(
         self,
-        camera: USBCamera = None,
+        camera: Camera = None,
         detect_qr: bool = True,
         detect_barcode: bool = True,
     ):
@@ -76,7 +76,7 @@ def __init__(
 
         self.already_seen_codes = set()
 
-        self._camera = camera if camera else USBCamera()
+        self._camera = camera if camera else Camera()
 
     def start(self):
         """Start the detector and begin scanning for codes."""
@@ -154,13 +154,13 @@ def loop(self):
                 self._on_error(e)
             return
 
-        # Use grayscale for barcode/QR code detection
-        gs_frame = cv2.cvtColor(np.asarray(frame), cv2.COLOR_RGB2GRAY)
-
-        self._on_frame(frame)
+        pil_frame = fromarray(frame)
+        self._on_frame(pil_frame)
 
+        # Use grayscale for barcode/QR code detection
+        gs_frame = greyscale(frame)
         detections = self._scan_frame(gs_frame)
-        self._on_detect(frame, detections)
+        self._on_detect(pil_frame, detections)
 
     def _on_frame(self, frame: Image):
         if self._on_frame_cb:
@@ -170,7 +170,7 @@ def _on_frame(self, frame: Image):
                 logger.error(f"Failed to run on_frame callback: {e}")
                 self._on_error(e)
 
-    def _scan_frame(self, frame: cv2.typing.MatLike) -> list[Detection]:
+    def _scan_frame(self, frame: np.ndarray) -> list[Detection]:
         """Scan the frame for a single barcode or QR code."""
         detections = []
 
diff --git a/src/arduino/app_bricks/camera_code_detection/examples/2_detection_list.py b/src/arduino/app_bricks/camera_code_detection/examples/2_detection_list.py
index 6288d571..e3021eb9 100644
--- a/src/arduino/app_bricks/camera_code_detection/examples/2_detection_list.py
+++ b/src/arduino/app_bricks/camera_code_detection/examples/2_detection_list.py
@@ -19,4 +19,4 @@ def on_codes_detected(frame: Image, detections: list[Detection]):
 detector = CameraCodeDetection()
 detector.on_detect(on_codes_detected)
 
-App.run()  # This will block until the app is stopped
+App.run()
diff --git a/src/arduino/app_bricks/camera_code_detection/examples/3_detection_with_overrides.py b/src/arduino/app_bricks/camera_code_detection/examples/3_detection_with_overrides.py
index 8a672470..fcd8ba3c 100644
--- a/src/arduino/app_bricks/camera_code_detection/examples/3_detection_with_overrides.py
+++ b/src/arduino/app_bricks/camera_code_detection/examples/3_detection_with_overrides.py
@@ -6,7 +6,7 @@
 # EXAMPLE_REQUIRES = "Requires an USB webcam connected to the Arduino board."
 from PIL.Image import Image
 from arduino.app_utils.app import App
-from arduino.app_peripherals.usb_camera import USBCamera
+from arduino.app_peripherals.camera import Camera
 from arduino.app_bricks.camera_code_detection import CameraCodeDetection, Detection
 
 
@@ -17,7 +17,7 @@ def on_code_detected(frame: Image, detection: Detection):
     # e.g., draw a bounding box, save it to a database or log it.
 
-camera = USBCamera(camera=0, resolution=(640, 360), fps=10)
+camera = Camera(2, resolution=(640, 360), fps=10)
 detector = CameraCodeDetection(camera)
 detector.on_detect(on_code_detected)
diff --git a/src/arduino/app_bricks/object_detection/README.md b/src/arduino/app_bricks/object_detection/README.md
index 3234ca67..9489e695 100644
--- a/src/arduino/app_bricks/object_detection/README.md
+++ b/src/arduino/app_bricks/object_detection/README.md
@@ -23,23 +23,24 @@ The Object Detection Brick allows you to:
 ```python
-import os
 from arduino.app_bricks.object_detection import ObjectDetection
+from arduino.app_utils.image import draw_bounding_boxes
 
 object_detection = ObjectDetection()
 
-# Image frame can be as bytes or PIL image
-frame = os.read("path/to/your/image.jpg")
+# Image can be provided as bytes or PIL.Image
+with open("path/to/your/image.jpg", "rb") as f:
+    img = f.read()
 
-out = object_detection.detect(frame)
-# is it possible to customize image type, confidence level and box overlap
-# out = object_detection.detect(frame, image_type = "png", confidence = 0.35, overlap = 0.5)
+out = object_detection.detect(img)
+# You can also provide a confidence level
+# out = object_detection.detect(img, confidence=0.35)
 if out and "detection" in out:
     for i, obj_det in enumerate(out["detection"]):
         # For every object detected, get its details
         detected_object = obj_det.get("class_name", None)
-        bounding_box = obj_det.get("bounding_box_xyxy", None)
         confidence = obj_det.get("confidence", None)
+        bounding_box = obj_det.get("bounding_box_xyxy", None)
 
-# draw the bounding box and key points on the image
-out_image = object_detection.draw_bounding_boxes(frame, out)
+# Draw the bounding boxes
+out_image = draw_bounding_boxes(img, out)
 ```
diff --git a/src/arduino/app_bricks/object_detection/__init__.py b/src/arduino/app_bricks/object_detection/__init__.py
index 93f2e290..3640fa52 100644
--- a/src/arduino/app_bricks/object_detection/__init__.py
+++ b/src/arduino/app_bricks/object_detection/__init__.py
@@ -2,8 +2,7 @@
 #
 # SPDX-License-Identifier: MPL-2.0
 
-from PIL import Image
-from arduino.app_utils import brick, Logger, draw_bounding_boxes
+from arduino.app_utils import brick, Logger
 from arduino.app_internal.core import EdgeImpulseRunnerFacade
 
 logger = Logger("ObjectDetection")
@@ -54,19 +53,6 @@ def detect(self, image_bytes, image_type: str = "jpg", confidence: float = None)
         ret = super().infer_from_image(image_bytes, image_type)
         return self._extract_detection(ret, confidence)
 
-    def draw_bounding_boxes(self, image: Image.Image | bytes, detections: dict) -> Image.Image | None:
-        """Draw bounding boxes on an image enclosing detected objects using PIL.
-
-        Args:
-            image: The input image to annotate. Can be a PIL Image object or raw image bytes.
-            detections: Detection results containing object labels and bounding boxes.
-
-        Returns:
-            Image with bounding boxes and key points drawn.
-            None if no detection or invalid image.
-        """
-        return draw_bounding_boxes(image, detections)
-
     def _extract_detection(self, item, confidence: float = None):
         if not item:
             return None
diff --git a/src/arduino/app_bricks/object_detection/examples/object_detection_example.py b/src/arduino/app_bricks/object_detection/examples/object_detection_example.py
index f2ca3b9f..80b92b20 100644
--- a/src/arduino/app_bricks/object_detection/examples/object_detection_example.py
+++ b/src/arduino/app_bricks/object_detection/examples/object_detection_example.py
@@ -3,23 +3,24 @@
 # SPDX-License-Identifier: MPL-2.0
 
 # EXAMPLE_NAME = "Object Detection"
 from arduino.app_bricks.object_detection import ObjectDetection
+from arduino.app_utils.image import draw_bounding_boxes
 
 object_detection = ObjectDetection()
 
-# Image frame can be as bytes or PIL image
-with open("image.png", "rb") as f:
-    frame = f.read()
+# Image can be provided as bytes or PIL.Image
+with open("path/to/your/image.jpg", "rb") as f:
+    img = f.read()
 
-out = object_detection.detect(frame)
-# is it possible to customize image type, confidence level and box overlap
-# out = object_detection.detect(frame, image_type = "png", confidence = 0.35, overlap = 0.5)
+out = object_detection.detect(img)
+# You can also provide a confidence level
+# out = object_detection.detect(img, confidence=0.35)
 if out and "detection" in out:
     for i, obj_det in enumerate(out["detection"]):
         # For every object detected, get its details
         detected_object = obj_det.get("class_name", None)
-        bounding_box = obj_det.get("bounding_box_xyxy", None)
         confidence = obj_det.get("confidence", None)
+        bounding_box = obj_det.get("bounding_box_xyxy", None)
 
-# draw the bounding box and key points on the image
-out_image = object_detection.draw_bounding_boxes(frame, out)
+# Draw the bounding boxes
+out_image = draw_bounding_boxes(img, out)
diff --git a/src/arduino/app_bricks/video_imageclassification/__init__.py b/src/arduino/app_bricks/video_imageclassification/__init__.py
index 87abab5e..437bc561 100644
--- a/src/arduino/app_bricks/video_imageclassification/__init__.py
+++ b/src/arduino/app_bricks/video_imageclassification/__init__.py
@@ -2,16 +2,21 @@
 #
 # SPDX-License-Identifier: MPL-2.0
 
-from arduino.app_utils import brick, Logger
-from arduino.app_internal.core import load_brick_compose_file, resolve_address
-from arduino.app_internal.core import EdgeImpulseRunnerFacade
-import threading
 import time
+import json
+import inspect
+import threading
+import socket
 from typing import Callable
+
 from websockets.sync.client import connect, ClientConnection
 from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError
-import json
-import inspect
+
+from arduino.app_peripherals.camera import Camera
+from arduino.app_internal.core import load_brick_compose_file, resolve_address
+from arduino.app_internal.core import EdgeImpulseRunnerFacade
+from arduino.app_utils.image import compress_to_jpeg
+from arduino.app_utils import brick, Logger
 
 logger = Logger("VideoImageClassification")
 
@@ -25,10 +30,11 @@ class VideoImageClassification:
 
     ALL_HANDLERS_KEY = "__ALL"
 
-    def __init__(self, confidence: float = 0.3, debounce_sec: float = 0.0):
+    def __init__(self, camera: Camera = None, confidence: float = 0.3, debounce_sec: float = 0.0):
         """Initialize the VideoImageClassification class.
 
         Args:
+            camera (Camera): The camera instance to use for capturing video. If None, a default camera will be initialized.
             confidence (float): The minimum confidence level for a classification to be considered valid.
Default is 0.3. debounce_sec (float): The minimum time in seconds between consecutive detections of the same object to avoid multiple triggers. Default is 0 seconds. @@ -36,6 +42,8 @@ def __init__(self, confidence: float = 0.3, debounce_sec: float = 0.0): Raises: RuntimeError: If the host address could not be resolved. """ + self._camera = camera if camera else Camera() + self._confidence = confidence self._debounce_sec = debounce_sec self._last_detected = {} @@ -114,40 +122,26 @@ def on_detect(self, object: str, callback: Callable[[], None]): self._handlers[object] = callback def start(self): - """Start the classification stream. - - This only sets the internal running flag. You must call - `execute` in a loop or a separate thread to actually begin receiving classification results. - """ + """Start the classification.""" + self._camera.start() self._is_running.set() def stop(self): - """Stop the classification stream and release resources. - - This clears the running flag. Any active `execute` loop - will exit gracefully at its next iteration. - """ + """Stop the classification and release resources.""" self._is_running.clear() + self._camera.stop() - def execute(self): - """Run the main classification loop. - - Behavior: - - Opens a WebSocket connection to the model runner. - - Receives classification messages in real time. - - Filters classifications below the confidence threshold. - - Applies debounce rules before invoking callbacks. - - Retries on transient connection errors until stopped. - - Exceptions: - ConnectionClosedOK: - Raised to exit when the server closes the connection cleanly. - ConnectionClosedError, TimeoutError, ConnectionRefusedError: - Logged and retried with backoff. + @brick.execute + def classification_loop(self): + """Classification main loop. + + Maintains WebSocket connection to the model runner and processes classification messages. + Retries on connection errors until stopped. """ while self._is_running.is_set(): try: with connect(self._uri) as ws: + logger.info("WebSocket connection established") while self._is_running.is_set(): try: message = ws.recv() @@ -157,21 +151,56 @@ def execute(self): except ConnectionClosedOK: raise except (TimeoutError, ConnectionRefusedError, ConnectionClosedError): - logger.warning(f"Connection lost. Retrying...") + logger.warning(f"WebSocket connection lost. Retrying...") raise except Exception as e: logger.exception(f"Failed to process detection: {e}") except ConnectionClosedOK: - logger.debug(f"Disconnected cleanly, exiting WebSocket read loop.") + logger.debug(f"WebSocket disconnected cleanly, exiting loop.") return except (TimeoutError, ConnectionRefusedError, ConnectionClosedError): logger.debug(f"Waiting for model runner. Retrying...") - import time - time.sleep(2) continue except Exception as e: logger.exception(f"Failed to establish WebSocket connection to {self._host}: {e}") + time.sleep(2) + + @brick.execute + def camera_loop(self): + """Camera main loop. + + Captures images from the camera and forwards them over the TCP connection. + Retries on connection errors until stopped. 
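+        Each frame is JPEG-compressed and written to the socket as raw bytes.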
+        """
+        while self._is_running.is_set():
+            try:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_socket:
+                    tcp_socket.connect((self._host, 5050))
+                    logger.info(f"TCP connection established to {self._host}:5050")
+
+                    while self._is_running.is_set():
+                        try:
+                            frame = self._camera.capture()
+                            if frame is None:
+                                time.sleep(0.01)  # Brief sleep if no image available
+                                continue
+
+                            jpeg_frame = compress_to_jpeg(frame)
+                            tcp_socket.sendall(jpeg_frame.tobytes())
+
+                        except (BrokenPipeError, ConnectionResetError, OSError) as e:
+                            logger.warning(f"TCP connection lost: {e}. Retrying...")
+                            break
+                        except Exception as e:
+                            logger.exception(f"Error capturing/sending image: {e}")
+
+            except (ConnectionRefusedError, OSError) as e:
+                logger.debug(f"TCP connection failed: {e}. Retrying in 2 seconds...")
+                time.sleep(2)
+            except Exception as e:
+                logger.exception(f"Unexpected error in TCP loop: {e}")
+                time.sleep(2)
 
     def _process_message(self, ws: ClientConnection, message: str):
         jmsg = json.loads(message)
diff --git a/src/arduino/app_bricks/video_imageclassification/brick_compose.yaml b/src/arduino/app_bricks/video_imageclassification/brick_compose.yaml
index 7e054acc..ff2a1495 100644
--- a/src/arduino/app_bricks/video_imageclassification/brick_compose.yaml
+++ b/src/arduino/app_bricks/video_imageclassification/brick_compose.yaml
@@ -9,11 +9,12 @@ services:
         max-size: "5m"
         max-file: "2"
     ports:
-      - ${BIND_ADDRESS:-0.0.0.0}:4912:4912
+      - ${BIND_ADDRESS:-0.0.0.0}:5050:5050 # TCP input for video frames
+      - ${BIND_ADDRESS:-0.0.0.0}:4912:4912 # Embedded UI port
     volumes:
       - "${CUSTOM_MODEL_PATH:-/home/arduino/.arduino-bricks/ei-models/}:${CUSTOM_MODEL_PATH:-/home/arduino/.arduino-bricks/ei-models/}"
       - "/run/udev:/run/udev"
-    command: ["--model-file", "${EI_CLASSIFICATION_MODEL:-/models/ootb/ei/mobilenet-v2-224px.eim}", "--dont-print-predictions", "--mode", "streaming", "--preview-original-resolution", "--camera", "${VIDEO_DEVICE:-/dev/video1}"]
+    command: ["--model-file", "${EI_CLASSIFICATION_MODEL:-/models/ootb/ei/mobilenet-v2-224px.eim}", "--dont-print-predictions", "--mode", "streaming-tcp-server", "--preview-original-resolution"]
     healthcheck:
       test: [ "CMD-SHELL", "wget -q --spider http://ei-video-classification-runner:4912 || exit 1" ]
       interval: 2s
diff --git a/src/arduino/app_bricks/video_objectdetection/__init__.py b/src/arduino/app_bricks/video_objectdetection/__init__.py
index b9372e04..7c41df73 100644
--- a/src/arduino/app_bricks/video_objectdetection/__init__.py
+++ b/src/arduino/app_bricks/video_objectdetection/__init__.py
@@ -2,16 +2,21 @@
 #
 # SPDX-License-Identifier: MPL-2.0
 
-from arduino.app_utils import brick, Logger
-from arduino.app_internal.core import load_brick_compose_file, resolve_address
-from arduino.app_internal.core import EdgeImpulseRunnerFacade
 import time
+import json
+import inspect
 import threading
+import socket
 from typing import Callable
+
 from websockets.sync.client import connect, ClientConnection
 from websockets.exceptions import ConnectionClosedOK, ConnectionClosedError
-import json
-import inspect
+
+from arduino.app_peripherals.camera import Camera
+from arduino.app_internal.core import load_brick_compose_file, resolve_address
+from arduino.app_internal.core import EdgeImpulseRunnerFacade
+from arduino.app_utils.image import compress_to_jpeg
+from arduino.app_utils import brick, Logger
 
 logger = Logger("VideoObjectDetection")
 
@@ -30,16 +35,19 @@ class VideoObjectDetection:
 
     ALL_HANDLERS_KEY = "__ALL"
 
-    def __init__(self, confidence: float = 0.3,
debounce_sec: float = 0.0): + def __init__(self, camera: Camera = None, confidence: float = 0.3, debounce_sec: float = 0.0): """Initialize the VideoObjectDetection class. Args: + camera (Camera): The camera instance to use for capturing video. If None, a default camera will be initialized. confidence (float): Confidence level for detection. Default is 0.3 (30%). debounce_sec (float): Minimum seconds between repeated detections of the same object. Default is 0 seconds. Raises: RuntimeError: If the host address could not be resolved. """ + self._camera = camera if camera else Camera() + self._confidence = confidence self._debounce_sec = debounce_sec self._last_detected: dict[str, float] = {} @@ -107,32 +115,25 @@ def on_detect_all(self, callback: Callable[[dict], None]): def start(self): """Start the video object detection process.""" + self._camera.start() self._is_running.set() def stop(self): - """Stop the video object detection process.""" + """Stop the video object detection process and release resources.""" self._is_running.clear() + self._camera.stop() + + @brick.execute + def object_detection_loop(self): + """Object detection main loop. - def execute(self): - """Connect to the model runner and process messages until `stop` is called. - - Behavior: - - Establishes a WebSocket connection to the runner. - - Parses ``"hello"`` messages to capture model metadata and optionally - performs a threshold override to align the runner with the local setting. - - Parses ``"classification"`` messages, filters detections by confidence, - applies debounce, then invokes registered callbacks. - - Retries on transient WebSocket errors while running. - - Exceptions: - ConnectionClosedOK: - Propagated to exit cleanly when the server closes the connection. - ConnectionClosedError, TimeoutError, ConnectionRefusedError: - Logged and retried with a short backoff while running. + Maintains WebSocket connection to the model runner and processes object detection messages. + Retries on connection errors until stopped. """ while self._is_running.is_set(): try: with connect(self._uri) as ws: + logger.info("WebSocket connection established") while self._is_running.is_set(): try: message = ws.recv() @@ -142,21 +143,56 @@ def execute(self): except ConnectionClosedOK: raise except (TimeoutError, ConnectionRefusedError, ConnectionClosedError): - logger.warning(f"Connection lost. Retrying...") + logger.warning(f"WebSocket connection lost. Retrying...") raise except Exception as e: logger.exception(f"Failed to process detection: {e}") except ConnectionClosedOK: - logger.debug(f"Disconnected cleanly, exiting WebSocket read loop.") + logger.debug(f"WebSocket disconnected cleanly, exiting loop.") return except (TimeoutError, ConnectionRefusedError, ConnectionClosedError): logger.debug(f"Waiting for model runner. Retrying...") - import time - time.sleep(2) continue except Exception as e: logger.exception(f"Failed to establish WebSocket connection to {self._host}: {e}") + time.sleep(2) + + @brick.execute + def camera_loop(self): + """Camera main loop. + + Captures images from the camera and forwards them over the TCP connection. + Retries on connection errors until stopped. 
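+        Each frame is JPEG-compressed and written to the socket as raw bytes.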
+        """
+        while self._is_running.is_set():
+            try:
+                with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as tcp_socket:
+                    tcp_socket.connect((self._host, 5050))
+                    logger.info(f"TCP connection established to {self._host}:5050")
+
+                    while self._is_running.is_set():
+                        try:
+                            frame = self._camera.capture()
+                            if frame is None:
+                                time.sleep(0.01)  # Brief sleep if no image available
+                                continue
+
+                            jpeg_frame = compress_to_jpeg(frame)
+                            tcp_socket.sendall(jpeg_frame.tobytes())
+
+                        except (BrokenPipeError, ConnectionResetError, OSError) as e:
+                            logger.warning(f"TCP connection lost: {e}. Retrying...")
+                            break
+                        except Exception as e:
+                            logger.exception(f"Error capturing/sending image: {e}")
+
+            except (ConnectionRefusedError, OSError) as e:
+                logger.debug(f"TCP connection failed: {e}. Retrying in 2 seconds...")
+                time.sleep(2)
+            except Exception as e:
+                logger.exception(f"Unexpected error in TCP loop: {e}")
+                time.sleep(2)
 
     def _process_message(self, ws: ClientConnection, message: str):
         jmsg = json.loads(message)
diff --git a/src/arduino/app_bricks/video_objectdetection/brick_compose.yaml b/src/arduino/app_bricks/video_objectdetection/brick_compose.yaml
index dbca6363..053e05e9 100644
--- a/src/arduino/app_bricks/video_objectdetection/brick_compose.yaml
+++ b/src/arduino/app_bricks/video_objectdetection/brick_compose.yaml
@@ -9,11 +9,12 @@ services:
         max-size: "5m"
         max-file: "2"
     ports:
-      - ${BIND_ADDRESS:-0.0.0.0}:4912:4912
+      - ${BIND_ADDRESS:-0.0.0.0}:5050:5050 # TCP input for video frames
+      - ${BIND_ADDRESS:-0.0.0.0}:4912:4912 # Embedded UI port
     volumes:
       - "${CUSTOM_MODEL_PATH:-/home/arduino/.arduino-bricks/ei-models/}:${CUSTOM_MODEL_PATH:-/home/arduino/.arduino-bricks/ei-models/}"
      - "/run/udev:/run/udev"
-    command: ["--model-file", "${EI_OBJ_DETECTION_MODEL:-/models/ootb/ei/yolo-x-nano.eim}", "--dont-print-predictions", "--mode", "streaming", "--force-target", "--preview-original-resolution", "--camera", "${VIDEO_DEVICE:-/dev/video1}"]
+    command: ["--model-file", "${EI_OBJ_DETECTION_MODEL:-/models/ootb/ei/yolo-x-nano.eim}", "--dont-print-predictions", "--mode", "streaming-tcp-server", "--preview-original-resolution"]
     healthcheck:
       test: [ "CMD-SHELL", "wget -q --spider http://ei-video-obj-detection-runner:4912 || exit 1" ]
       interval: 2s
diff --git a/src/arduino/app_bricks/visual_anomaly_detection/examples/object_detection_example.py b/src/arduino/app_bricks/visual_anomaly_detection/examples/object_detection_example.py
deleted file mode 100644
index 5dc0d2cc..00000000
--- a/src/arduino/app_bricks/visual_anomaly_detection/examples/object_detection_example.py
+++ /dev/null
@@ -1,25 +0,0 @@
-# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
-#
-# SPDX-License-Identifier: MPL-2.0
-
-# EXAMPLE_NAME = "Object Detection"
-import os
-from arduino.app_bricks.object_detection import ObjectDetection
-
-object_detection = ObjectDetection()
-
-# Image frame can be as bytes or PIL image
-frame = os.read("path/to/your/image.jpg")
-
-out = object_detection.detect(frame)
-# is it possible to customize image type, confidence level and box overlap
-# out = object_detection.detect(frame, image_type = "png", confidence = 0.35, overlap = 0.5)
-if out and "detection" in out:
-    for i, obj_det in enumerate(out["detection"]):
-        # For every object detected, get its details
-        detected_object = obj_det.get("class_name", None)
-        bounding_box = obj_det.get("bounding_box_xyxy", None)
-        confidence = obj_det.get("confidence", None)
-
-# draw the bounding box and key points on the image
-out_image = object_detection.draw_bounding_boxes(frame, out)
diff --git a/src/arduino/app_bricks/visual_anomaly_detection/examples/visual_anomaly_example.py b/src/arduino/app_bricks/visual_anomaly_detection/examples/visual_anomaly_example.py
new file mode 100644
index 00000000..cbab3310
--- /dev/null
+++ b/src/arduino/app_bricks/visual_anomaly_detection/examples/visual_anomaly_example.py
@@ -0,0 +1,24 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+# EXAMPLE_NAME = "Visual Anomaly Detection"
+from arduino.app_bricks.visual_anomaly_detection import VisualAnomalyDetection
+from arduino.app_utils.image import draw_anomaly_markers
+
+anomaly_detection = VisualAnomalyDetection()
+
+# Image can be provided as bytes or PIL.Image
+with open("path/to/your/image.jpg", "rb") as f:
+    img = f.read()
+
+out = anomaly_detection.detect(img)
+if out and "detection" in out:
+    for i, anomaly in enumerate(out["detection"]):
+        # For every anomaly detected, get its details
+        detected_anomaly = anomaly.get("class_name", None)
+        score = anomaly.get("score", None)
+        bounding_box = anomaly.get("bounding_box_xyxy", None)
+
+# Draw the anomaly markers
+out_image = draw_anomaly_markers(img, out)
diff --git a/src/arduino/app_internal/core/ei.py b/src/arduino/app_internal/core/ei.py
index 825b8435..8753a317 100644
--- a/src/arduino/app_internal/core/ei.py
+++ b/src/arduino/app_internal/core/ei.py
@@ -5,8 +5,8 @@
 import requests
 import io
 from arduino.app_internal.core import load_brick_compose_file, resolve_address
-from arduino.app_utils import get_image_bytes, get_image_type, HttpClient
-from arduino.app_utils import Logger
+from arduino.app_utils.image import get_image_bytes, get_image_type
+from arduino.app_utils import Logger, HttpClient
 
 logger = Logger(__name__)
diff --git a/src/arduino/app_internal/pipeline/pipeline.py b/src/arduino/app_internal/pipeline/pipeline.py
index 58e2cbc4..027e4eb0 100644
--- a/src/arduino/app_internal/pipeline/pipeline.py
+++ b/src/arduino/app_internal/pipeline/pipeline.py
@@ -177,11 +177,13 @@ def _run_loop(self, loop_ready_event: threading.Event):
 
             self._loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
             self._loop.run_until_complete(self._loop.shutdown_asyncgens())
-            self._loop.close()
-            logger.debug("Internal event loop stopped.")
         except Exception as e:
             logger.exception(f"Error during event loop cleanup: {e}")
-        self._loop = None
+        finally:
+            if self._loop and not self._loop.is_closed():
+                self._loop.close()
+            self._loop = None
+            logger.debug("Internal event loop stopped.")
 
     async def _async_run_pipeline(self):
         """The main async logic using Adapters."""
diff --git a/src/arduino/app_peripherals/camera/README.md b/src/arduino/app_peripherals/camera/README.md
new file mode 100644
index 00000000..a6512a2d
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/README.md
@@ -0,0 +1,149 @@
+# Camera
+
+The `Camera` peripheral provides a unified abstraction for capturing images from different camera types and protocols.
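+
+For example, the same call pattern drives a local V4L camera, an IP stream, or a WebSocket source (a minimal sketch; the RTSP and WebSocket URLs below are placeholders):
+
+```python
+from arduino.app_peripherals.camera import Camera
+
+for source in (0, "rtsp://192.168.1.100/stream", "ws://0.0.0.0:8080"):
+    with Camera(source) as camera:  # same interface regardless of backend
+        frame = camera.capture()  # numpy array, or None if no frame is ready
+```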
+
+## Features
+
+- **Universal Interface**: Single API for V4L/USB, IP cameras, and WebSocket cameras
+- **Automatic Detection**: Selects the appropriate camera implementation based on the source
+- **Multiple Protocols**: Supports V4L, RTSP, HTTP/MJPEG, and WebSocket streams
+- **Thread-Safe**: Safe concurrent access with proper locking
+- **Context Manager**: Automatic resource management
+
+## Quick Start
+
+Instantiate the default camera:
+```python
+from arduino.app_peripherals.camera import Camera
+
+# Default camera (V4L camera at index 0)
+camera = Camera()
+```
+
+The camera needs to be started and stopped explicitly:
+
+```python
+# Specify camera and configuration
+camera = Camera(0, resolution=(640, 480), fps=15)
+camera.start()
+
+image = camera.capture()
+
+camera.stop()
+```
+
+Or you can use it as a context manager, which starts and stops it automatically:
+```python
+with Camera(source, **options) as camera:
+    frame = camera.capture()
+    if frame is not None:
+        print(f"Captured frame with shape: {frame.shape}")
+    # Camera automatically stopped when exiting
+```
+
+## Frame Adjustments
+
+The `adjustments` parameter lets you apply custom transformations to captured frames. It accepts a callable that takes a numpy array (the frame) and returns a modified numpy array. You can also build adjustment pipelines by chaining these functions with the pipe (|) operator:
+
+```python
+import cv2
+from arduino.app_peripherals.camera import Camera
+# PipeableFunction is assumed to be exported alongside the built-in adjustments
+from arduino.app_utils.image import greyscaled, PipeableFunction
+
+
+def apply_blur(frame):
+    return cv2.GaussianBlur(frame, (15, 15), 0)
+
+# Wrap the function so it can be chained with the | operator
+blurred = PipeableFunction(apply_blur)
+
+# Using adjustments with Camera
+with Camera(0, adjustments=greyscaled) as camera:
+    frame = camera.capture()
+    # frame is now greyscaled
+
+# Or with multiple transformations
+with Camera(0, adjustments=greyscaled | blurred) as camera:
+    frame = camera.capture()
+    # frame is now greyscaled and blurred
+```
+
+See the `arduino.app_utils.image` module for more supported adjustments.
+
+## Camera Types
+
+The Camera class automatically detects the camera type based on the format of its source argument. Keyword arguments are propagated to the underlying implementation.
+
+Note: Camera's constructor arguments (except those in its signature) must be provided in keyword format to forward them correctly to the specific camera implementations.
+
+The underlying camera implementations (V4LCamera, IPCamera and WebSocketCamera) can also be instantiated explicitly, if needed.
+
+### V4L Cameras
+For local USB cameras and V4L-compatible devices.
+
+**Features:**
+- Supports cameras compatible with the Video4Linux2 drivers
+
+```python
+camera = Camera(0)  # Camera index
+camera = Camera("/dev/video0")  # Device path
+camera = V4LCamera(0)
+```
+
+### IP Cameras
+For network cameras supporting RTSP (Real-Time Streaming Protocol) and HLS (HTTP Live Streaming).
+
+**Features:**
+- Captures RTSP and HLS streams
+- Authentication support
+- Automatic reconnection
+
+```python
+camera = Camera("rtsp://admin:secret@192.168.1.100/stream")
+camera = Camera("http://camera.local/stream",
+                username="admin", password="secret")
+camera = IPCamera("http://camera.local/stream",
+                  username="admin", password="secret")
+```
+
+### WebSocket Cameras
+For hosting a WebSocket server that receives frames from a single client at a time.
+
+**Features:**
+- **Single client limitation**: Only one client can connect at a time
+- Stream data from any client with WebSocket support
+- Base64, binary, and JSON frame formats
+- Supports 8-bit images (e.g. JPEG, PNG 8-bit)
+
+```python
+# frame_format must match what the client sends (see the client example below)
+camera = Camera("ws://0.0.0.0:8080", timeout=5, frame_format="base64")
+camera = WebSocketCamera("0.0.0.0", 8080, timeout=5, frame_format="base64")
+```
+
+Client implementation example:
+```python
+import time
+import base64
+import cv2
+import websockets.sync.client as wsclient
+import websockets.exceptions as wsexc
+
+
+# Open camera
+camera = cv2.VideoCapture(0)
+# Replace <board-address> with the address of the board hosting the WebSocket camera
+with wsclient.connect("ws://<board-address>:8080") as websocket:
+    while True:
+        time.sleep(1.0 / 15.0)  # 15 FPS
+        ret, frame = camera.read()
+        if ret:
+            # Compress frame to JPEG
+            _, buffer = cv2.imencode('.jpg', frame)
+            # Convert to base64
+            jpeg_b64 = base64.b64encode(buffer).decode('utf-8')
+            try:
+                websocket.send(jpeg_b64)
+            except wsexc.ConnectionClosed:
+                break
+```
+
+## Migration from Legacy Camera
+
+The new Camera abstraction is backward compatible with the legacy USBCamera peripheral. Existing code using the old API will continue to work, but will use the new Camera backend. New code should use the improved abstraction for better flexibility and features.
diff --git a/src/arduino/app_peripherals/camera/__init__.py b/src/arduino/app_peripherals/camera/__init__.py
new file mode 100644
index 00000000..ada0b326
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/__init__.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+from .camera import Camera
+from .v4l_camera import V4LCamera
+from .ip_camera import IPCamera
+from .websocket_camera import WebSocketCamera
+from .errors import *
+
+__all__ = [
+    "Camera",
+    "V4LCamera",
+    "IPCamera",
+    "WebSocketCamera",
+    "CameraError",
+    "CameraConfigError",
+    "CameraOpenError",
+    "CameraReadError",
+    "CameraTransformError",
+]
diff --git a/src/arduino/app_peripherals/camera/base_camera.py b/src/arduino/app_peripherals/camera/base_camera.py
new file mode 100644
index 00000000..f37e51ce
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/base_camera.py
@@ -0,0 +1,158 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+import threading
+import time
+from abc import ABC, abstractmethod
+from typing import Optional, Callable
+import numpy as np
+
+from arduino.app_utils import Logger
+
+from .errors import CameraOpenError, CameraTransformError
+
+logger = Logger("Camera")
+
+
+class BaseCamera(ABC):
+    """
+    Abstract base class for camera implementations.
+
+    This class defines the common interface that all camera implementations must follow,
+    providing a unified API regardless of the underlying camera protocol or type.
+    """
+
+    def __init__(
+        self,
+        resolution: tuple[int, int] = (640, 480),
+        fps: int = 10,
+        adjustments: Callable[[np.ndarray], np.ndarray] = None,
+    ):
+        """
+        Initialize the camera base.
+
+        Args:
+            resolution (tuple, optional): Resolution as (width, height). None uses default resolution.
+            fps (int): Frames per second to capture from the camera.
+            adjustments (callable, optional): Function or function pipeline to adjust frames that takes
+                a numpy array and returns a numpy array.
Default: None + """ + self.resolution = resolution + self.fps = fps + self.adjustments = adjustments + self.logger = logger # This will be overridden by subclasses if needed + + self._camera_lock = threading.Lock() + self._is_started = False + self._last_capture_time = time.monotonic() + self._desired_interval = 1.0 / fps if fps > 0 else 0 + + def start(self) -> None: + """Start the camera capture.""" + with self._camera_lock: + if self._is_started: + return + + try: + self._open_camera() + self._is_started = True + self._last_capture_time = time.monotonic() + self.logger.info(f"Successfully started {self.__class__.__name__}") + except Exception as e: + raise CameraOpenError(f"Failed to start camera: {e}") + + def stop(self) -> None: + """Stop the camera and release resources.""" + with self._camera_lock: + if not self._is_started: + return + + try: + self._close_camera() + self._is_started = False + self.logger.info(f"Stopped {self.__class__.__name__}") + except Exception as e: + self.logger.warning(f"Error stopping camera: {e}") + + def capture(self) -> Optional[np.ndarray]: + """ + Capture a frame from the camera, respecting the configured FPS. + + Returns: + Numpy array or None if no frame is available. + """ + frame = self._extract_frame() + if frame is None: + return None + return frame + + def is_started(self) -> bool: + """Check if the camera is started.""" + return self._is_started + + def stream(self): + """ + Continuously capture frames from the camera. + + This is a generator that yields frames continuously while the camera is started. + Built on top of capture() for convenience. + + Yields: + np.ndarray: Video frames as numpy arrays. + """ + while self._is_started: + frame = self.capture() + if frame is not None: + yield frame + + def _extract_frame(self) -> np.ndarray | None: + """Extract a frame with FPS throttling and post-processing.""" + with self._camera_lock: + # FPS throttling + if self._desired_interval > 0: + current_time = time.monotonic() + elapsed = current_time - self._last_capture_time + if elapsed < self._desired_interval: + time.sleep(self._desired_interval - elapsed) + + if not self._is_started: + return None + + frame = self._read_frame() + if frame is None: + return None + + self._last_capture_time = time.monotonic() + + if self.adjustments is not None: + try: + frame = self.adjustments(frame) + except Exception as e: + raise CameraTransformError(f"Frame transformation failed ({self.adjustments}): {e}") + + return frame + + @abstractmethod + def _open_camera(self) -> None: + """Open the camera connection. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _close_camera(self) -> None: + """Close the camera connection. Must be implemented by subclasses.""" + pass + + @abstractmethod + def _read_frame(self) -> Optional[np.ndarray]: + """Read a single frame from the camera. 
Must be implemented by subclasses."""
+        pass
+
+    def __enter__(self):
+        """Context manager entry."""
+        self.start()
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Context manager exit."""
+        self.stop()
diff --git a/src/arduino/app_peripherals/camera/camera.py b/src/arduino/app_peripherals/camera/camera.py
new file mode 100644
index 00000000..733062e4
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/camera.py
@@ -0,0 +1,122 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+from collections.abc import Callable
+from urllib.parse import urlparse
+
+import numpy as np
+
+from .base_camera import BaseCamera
+from .errors import CameraConfigError
+
+
+class Camera:
+    """
+    Unified Camera class that can be configured for different camera types.
+
+    This class serves as both a factory and a wrapper, automatically creating
+    the appropriate camera implementation based on the provided configuration.
+
+    Supports:
+    - V4L Cameras (local cameras connected to the system), the default
+    - IP Cameras (network-based cameras via RTSP, HLS)
+    - WebSocket Cameras (input video streams via WebSocket client)
+
+    Note: constructor arguments (except those in the signature) must be provided in
+    keyword format to forward them correctly to the specific camera implementations.
+    """
+
+    def __new__(
+        cls,
+        source: str | int = 0,
+        resolution: tuple[int, int] = (640, 480),
+        fps: int = 10,
+        adjustments: Callable[[np.ndarray], np.ndarray] = None,
+        **kwargs,
+    ) -> BaseCamera:
+        """Create a camera instance based on the source type.
+
+        Args:
+            source (Union[str, int]): Camera source identifier. Supports:
+                - int: V4L camera index (e.g., 0, 1)
+                - str: V4L camera index (e.g., "0", "1") or device path (e.g., "/dev/video0")
+                - str: URL for IP cameras (e.g., "rtsp://...", "http://...")
+                - str: WebSocket URL for input streams (e.g., "ws://0.0.0.0:8080")
+            resolution (tuple, optional): Frame resolution as (width, height).
+                Default: (640, 480)
+            fps (int, optional): Target frames per second. Default: 10
+            adjustments (callable, optional): Function pipeline to adjust frames that takes a
+                numpy array and returns a numpy array. Default: None
+            **kwargs: Camera-specific configuration parameters grouped by type.
+                V4L cameras take no additional parameters; the WebSocket host and
+                port are parsed from the source URL and must not be passed again.
+                IP Camera Parameters:
+                    username (str, optional): Authentication username
+                    password (str, optional): Authentication password
+                    timeout (float, optional): Connection timeout in seconds. Default: 10.0
+                WebSocket Camera Parameters:
+                    timeout (float, optional): Connection timeout in seconds. Default: 10.0
+                    frame_format (str, optional): Expected frame format ("base64", "binary",
Default: "base64" + + Returns: + BaseCamera: Appropriate camera implementation instance + + Raises: + CameraConfigError: If source type is not supported or parameters are invalid + CameraOpenError: If the camera cannot be opened + + Examples: + V4L Camera: + + ```python + camera = Camera(0, resolution=(640, 480), fps=30) + camera = Camera("/dev/video1", fps=15) + ``` + + IP Camera: + + ```python + camera = Camera("rtsp://192.168.1.100:554/stream", username="admin", password="secret", timeout=15.0) + camera = Camera("http://192.168.1.100:8080/video.mp4") + ``` + + WebSocket Camera: + + ```python + camera = Camera("ws://0.0.0.0:8080", frame_format="json") + camera = Camera("ws://192.168.1.100:8080", timeout=5) + ``` + """ + if isinstance(source, int) or (isinstance(source, str) and source.isdigit()): + # V4L Camera + from .v4l_camera import V4LCamera + + return V4LCamera(source, resolution=resolution, fps=fps, adjustments=adjustments, **kwargs) + elif isinstance(source, str): + parsed = urlparse(source) + if parsed.scheme in ["http", "https", "rtsp"]: + # IP Camera + from .ip_camera import IPCamera + + return IPCamera(source, resolution=resolution, fps=fps, adjustments=adjustments, **kwargs) + elif parsed.scheme in ["ws", "wss"]: + # WebSocket Camera - extract host and port from URL + from .websocket_camera import WebSocketCamera + + host = parsed.hostname or "localhost" + port = parsed.port or 8080 + return WebSocketCamera(host=host, port=port, resolution=resolution, fps=fps, adjustments=adjustments, **kwargs) + elif source.startswith("/dev/video") or source.isdigit(): + # V4L device path or index as string + from .v4l_camera import V4LCamera + + return V4LCamera(source, resolution=resolution, fps=fps, adjustments=adjustments, **kwargs) + else: + raise CameraConfigError(f"Unsupported camera source: {source}") + else: + raise CameraConfigError(f"Invalid source type: {type(source)}") diff --git a/src/arduino/app_peripherals/camera/errors.py b/src/arduino/app_peripherals/camera/errors.py new file mode 100644 index 00000000..6b20999f --- /dev/null +++ b/src/arduino/app_peripherals/camera/errors.py @@ -0,0 +1,33 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA +# +# SPDX-License-Identifier: MPL-2.0 + + +class CameraError(Exception): + """Base exception for camera-related errors.""" + + pass + + +class CameraOpenError(CameraError): + """Exception raised when the camera cannot be opened.""" + + pass + + +class CameraReadError(CameraError): + """Exception raised when reading from camera fails.""" + + pass + + +class CameraConfigError(CameraError): + """Exception raised when camera configuration is invalid.""" + + pass + + +class CameraTransformError(CameraError): + """Exception raised when frame transformation fails.""" + + pass diff --git a/src/arduino/app_peripherals/camera/examples/1_initialize.py b/src/arduino/app_peripherals/camera/examples/1_initialize.py new file mode 100644 index 00000000..f720708c --- /dev/null +++ b/src/arduino/app_peripherals/camera/examples/1_initialize.py @@ -0,0 +1,18 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA +# +# SPDX-License-Identifier: MPL-2.0 + +# EXAMPLE_NAME = "Initialize camera input" +# EXAMPLE_REQUIRES = "Requires a connected camera" +from arduino.app_peripherals.camera import Camera, V4LCamera + + +default = Camera() # Uses default camera (V4L) + +# The following two are equivalent +camera = Camera(2, resolution=(640, 480), fps=15) # Infers camera type +v4l = V4LCamera(2, (640, 480), 15) # Explicitly requests V4L camera + +# 
+# Note: Camera's constructor arguments (except those in its signature)
+# must be provided in keyword format to forward them correctly to the
+# specific camera implementations.
diff --git a/src/arduino/app_peripherals/camera/examples/2_capture_image.py b/src/arduino/app_peripherals/camera/examples/2_capture_image.py
new file mode 100644
index 00000000..f0e92f10
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/examples/2_capture_image.py
@@ -0,0 +1,14 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+# EXAMPLE_NAME = "Capture an image"
+# EXAMPLE_REQUIRES = "Requires a connected camera"
+import numpy as np
+from arduino.app_peripherals.camera import Camera
+
+
+camera = Camera()
+camera.start()
+image: np.ndarray = camera.capture()
+camera.stop()
diff --git a/src/arduino/app_peripherals/camera/examples/3_capture_video.py b/src/arduino/app_peripherals/camera/examples/3_capture_video.py
new file mode 100644
index 00000000..4e38ad03
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/examples/3_capture_video.py
@@ -0,0 +1,21 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+# EXAMPLE_NAME = "Capture a video"
+# EXAMPLE_REQUIRES = "Requires a connected camera"
+import time
+import numpy as np
+from arduino.app_peripherals.camera import Camera
+
+
+# Capture a video for 5 seconds at 15 FPS
+camera = Camera(fps=15)
+camera.start()
+
+start_time = time.time()
+while time.time() - start_time < 5:
+    image: np.ndarray = camera.capture()
+    # You can process the image here if needed, e.g. save it
+
+camera.stop()
diff --git a/src/arduino/app_peripherals/camera/examples/4_capture_hls.py b/src/arduino/app_peripherals/camera/examples/4_capture_hls.py
new file mode 100644
index 00000000..0a7a5e5d
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/examples/4_capture_hls.py
@@ -0,0 +1,23 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+# EXAMPLE_NAME = "Capture an HLS (HTTP Live Stream) video"
+import time
+import numpy as np
+from arduino.app_peripherals.camera import Camera
+
+
+# Capture a freely available HLS playlist for testing
+# Note: Public streams can be unreliable and may go offline without notice.
+url = "https://demo.unified-streaming.com/k8s/features/stable/video/tears-of-steel/tears-of-steel.ism/.m3u8"
+
+camera = Camera(url)
+camera.start()
+
+start_time = time.time()
+while time.time() - start_time < 5:
+    image: np.ndarray = camera.capture()
+    # You can process the image here if needed, e.g. save it
+
+camera.stop()
diff --git a/src/arduino/app_peripherals/camera/examples/5_capture_rtsp.py b/src/arduino/app_peripherals/camera/examples/5_capture_rtsp.py
new file mode 100644
index 00000000..955e5e66
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/examples/5_capture_rtsp.py
@@ -0,0 +1,23 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+# EXAMPLE_NAME = "Capture an RTSP (Real-Time Streaming Protocol) video"
+import time
+import numpy as np
+from arduino.app_peripherals.camera import Camera
+
+
+# Capture a freely available RTSP stream for testing
+# Note: Public streams can be unreliable and may go offline without notice.
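+# If the stream requires credentials, they can be passed explicitly,
+# e.g. Camera(url, username="admin", password="secret")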
+url = "rtsp://170.93.143.139/rtplive/470011e600ef003a004ee33696235daa"
+
+camera = Camera(url)
+camera.start()
+
+start_time = time.time()
+while time.time() - start_time < 5:
+    image: np.ndarray = camera.capture()
+    # You can process the image here if needed, e.g. save it
+
+camera.stop()
diff --git a/src/arduino/app_peripherals/camera/examples/6_capture_websocket.py b/src/arduino/app_peripherals/camera/examples/6_capture_websocket.py
new file mode 100644
index 00000000..14235760
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/examples/6_capture_websocket.py
@@ -0,0 +1,20 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+# EXAMPLE_NAME = "Capture an input WebSocket video"
+import time
+import numpy as np
+from arduino.app_peripherals.camera import Camera
+
+
+# Host a WebSocket server that receives frames from a connected client
+camera = Camera("ws://0.0.0.0:8080", timeout=5)
+camera.start()
+
+start_time = time.time()
+while time.time() - start_time < 5:
+    image: np.ndarray = camera.capture()
+    # You can process the image here if needed, e.g. save it
+
+camera.stop()
diff --git a/src/arduino/app_peripherals/camera/ip_camera.py b/src/arduino/app_peripherals/camera/ip_camera.py
new file mode 100644
index 00000000..3043439f
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/ip_camera.py
@@ -0,0 +1,147 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+import cv2
+import numpy as np
+import requests
+from urllib.parse import urlparse
+from collections.abc import Callable
+
+from arduino.app_utils import Logger
+
+from .base_camera import BaseCamera
+from .errors import CameraConfigError, CameraOpenError
+
+logger = Logger("IPCamera")
+
+
+class IPCamera(BaseCamera):
+    """
+    IP Camera implementation for network-based cameras.
+
+    Supports RTSP, HTTP, and HTTPS camera streams.
+    Can handle authentication and various streaming protocols.
+    """
+
+    def __init__(
+        self,
+        url: str,
+        username: str | None = None,
+        password: str | None = None,
+        timeout: int = 10,
+        resolution: tuple[int, int] = (640, 480),
+        fps: int = 10,
+        adjustments: Callable[[np.ndarray], np.ndarray] = None,
+    ):
+        """
+        Initialize IP camera.
+
+        Args:
+            url: Camera stream URL (i.e. rtsp://..., http://..., https://...)
+            username: Optional authentication username
+            password: Optional authentication password
+            timeout: Connection timeout in seconds
+            resolution (tuple, optional): Resolution as (width, height). None uses default resolution.
+            fps (int): Frames per second to capture from the camera.
+            adjustments (callable, optional): Function or function pipeline to adjust frames that takes
+                a numpy array and returns a numpy array. Default: None
+        """
+        super().__init__(resolution, fps, adjustments)
+        self.url = url
+        self.username = username
+        self.password = password
+        self.timeout = timeout
+        self.logger = logger
+
+        self._cap = None
+
+        self._validate_url()
+
+    def _validate_url(self) -> None:
+        """Validate the camera URL format."""
+        try:
+            parsed = urlparse(self.url)
+        except Exception as e:
+            raise CameraConfigError(f"Invalid URL format: {e}")
+        if parsed.scheme not in ["http", "https", "rtsp"]:
+            raise CameraConfigError(f"Unsupported URL scheme: {parsed.scheme}")
+
+    def _open_camera(self) -> None:
+        """Open the IP camera connection."""
+        url = self._build_url()
+
+        # Test connectivity first for HTTP streams
+        if self.url.startswith(("http://", "https://")):
+            self._test_http_connectivity()
+
+        self._cap = cv2.VideoCapture(url)
+        if not self._cap.isOpened():
+            raise CameraOpenError(f"Failed to open IP camera: {self.url}")
+
+        self._cap.set(cv2.CAP_PROP_BUFFERSIZE, 1)  # Reduce buffer to minimize latency
+
+        # Test by reading one frame
+        ret, frame = self._cap.read()
+        if not ret or frame is None:
+            self._cap.release()
+            self._cap = None
+            raise CameraOpenError(f"Cannot read from IP camera: {self.url}")
+
+        logger.info(f"Opened IP camera: {self.url}")
+
+    def _build_url(self) -> str:
+        """Build URL with authentication if credentials provided."""
+        # If no username or password provided as parameters, return original URL
+        if not self.username or not self.password:
+            return self.url
+
+        parsed = urlparse(self.url)
+
+        # Override any URL credentials if credentials are provided
+        auth_netloc = f"{self.username}:{self.password}@{parsed.hostname}"
+        if parsed.port:
+            auth_netloc += f":{parsed.port}"
+
+        return f"{parsed.scheme}://{auth_netloc}{parsed.path}"
+
+    def _test_http_connectivity(self) -> None:
+        """Test HTTP/HTTPS camera connectivity."""
+        try:
+            auth = None
+            if self.username and self.password:
+                auth = (self.username, self.password)
+
+            response = requests.head(self.url, auth=auth, timeout=self.timeout, allow_redirects=True)
+
+            if response.status_code not in [200, 206]:  # 206 for partial content
+                raise CameraOpenError(f"HTTP camera returned status {response.status_code}: {self.url}")
+
+        except requests.RequestException as e:
+            raise CameraOpenError(f"Cannot connect to HTTP camera {self.url}: {e}")
+
+    def _close_camera(self) -> None:
+        """Close the IP camera connection."""
+        if self._cap is not None:
+            self._cap.release()
+            self._cap = None
+
+    def _read_frame(self) -> np.ndarray | None:
+        """Read a frame from the IP camera with automatic reconnection."""
+        if self._cap is None:
+            logger.info(f"No connection to IP camera {self.url}, attempting to reconnect")
+            try:
+                self._open_camera()
+            except Exception as e:
+                logger.error(f"Failed to reconnect to IP camera {self.url}: {e}")
+                return None
+
+        ret, frame = self._cap.read()
+        if ret and frame is not None:
+            return frame
+
+        if not self._cap.isOpened():
+            logger.warning(f"IP camera connection dropped: {self.url}")
+            self._close_camera()  # Will reconnect on next call
+
+        return None
diff --git a/src/arduino/app_peripherals/camera/v4l_camera.py b/src/arduino/app_peripherals/camera/v4l_camera.py
new file mode 100644
index 00000000..0256f401
--- /dev/null
+++ b/src/arduino/app_peripherals/camera/v4l_camera.py
@@ -0,0 +1,172 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+import os
+import re
+import cv2
+import numpy as np
+from collections.abc import Callable
+
+from arduino.app_utils import Logger
+
+from .base_camera import BaseCamera
+from .errors import CameraOpenError, CameraReadError
+
+logger = Logger("V4LCamera")
+
+
+class V4LCamera(BaseCamera):
+    """
+    V4L (Video4Linux) camera implementation for USB and local cameras.
+
+    This class handles USB cameras and other V4L-compatible devices on Linux systems.
+    It supports both device indices and device paths.
+    """
+
+    def __init__(
+        self,
+        device: str | int = 0,
+        resolution: tuple[int, int] = (640, 480),
+        fps: int = 10,
+        adjustments: Callable[[np.ndarray], np.ndarray] = None,
+    ):
+        """
+        Initialize V4L camera.
+
+        Args:
+            device: Camera identifier - can be:
+                - int: Camera index (e.g., 0, 1)
+                - str: Camera index as string or device path
+            resolution (tuple, optional): Resolution as (width, height). None uses default resolution.
+            fps (int, optional): Frames per second to capture from the camera. Default: 10.
+            adjustments (callable, optional): Function or function pipeline to adjust frames that takes
+                a numpy array and returns a numpy array. Default: None
+        """
+        super().__init__(resolution, fps, adjustments)
+        self.device = self._resolve_camera_id(device)
+        self.logger = logger
+
+        self._cap = None
+
+    def _resolve_camera_id(self, device: str | int) -> int:
+        """
+        Resolve camera identifier to a numeric device ID.
+
+        Args:
+            device: Camera identifier
+
+        Returns:
+            Numeric camera device ID
+
+        Raises:
+            CameraOpenError: If camera cannot be resolved
+        """
+        if isinstance(device, int):
+            return device
+
+        if isinstance(device, str):
+            # If it's a numeric string, convert directly
+            if device.isdigit():
+                device_idx = int(device)
+                # Validate using device index mapping
+                video_devices = self._get_video_devices_by_index()
+                if device_idx in video_devices:
+                    return int(video_devices[device_idx])
+                else:
+                    # Fallback to direct device ID if mapping not available
+                    return device_idx
+
+            # If it's a device path like "/dev/video0"
+            if device.startswith("/dev/video"):
+                return int(device.replace("/dev/video", ""))
+
+        raise CameraOpenError(f"Cannot resolve camera identifier: {device}")
+
+    def _get_video_devices_by_index(self) -> dict[int, str]:
+        """
+        Map camera indices to device numbers by reading /dev/v4l/by-id/.
+ + Returns: + Dict mapping index to device number + """ + devices_by_index = {} + directory_path = "/dev/v4l/by-id/" + + # Check if the directory exists + if not os.path.exists(directory_path): + logger.warning(f"Directory '{directory_path}' not found.") + return devices_by_index + + try: + entries = os.listdir(directory_path) + for entry in entries: + full_path = os.path.join(directory_path, entry) + + if os.path.islink(full_path): + # Find numeric index at end of filename + match = re.search(r"index(\d+)$", entry) + if match: + try: + index = int(match.group(1)) + resolved_path = os.path.realpath(full_path) + device_name = os.path.basename(resolved_path) + device_number = device_name.replace("video", "") + devices_by_index[index] = device_number + except ValueError: + logger.warning(f"Could not parse index from '{entry}'") + continue + except OSError as e: + logger.error(f"Error accessing directory '{directory_path}': {e}") + + return devices_by_index + + def _open_camera(self) -> None: + """Open the V4L camera connection.""" + self._cap = cv2.VideoCapture(self.device) + if not self._cap.isOpened(): + raise CameraOpenError(f"Failed to open V4L camera {self.device}") + + self._cap.set(cv2.CAP_PROP_BUFFERSIZE, 1) # Reduce buffer to minimize latency + + # Set resolution if specified + if self.resolution and self.resolution[0] and self.resolution[1]: + self._cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.resolution[0]) + self._cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.resolution[1]) + + # Verify resolution setting + actual_width = int(self._cap.get(cv2.CAP_PROP_FRAME_WIDTH)) + actual_height = int(self._cap.get(cv2.CAP_PROP_FRAME_HEIGHT)) + if actual_width != self.resolution[0] or actual_height != self.resolution[1]: + logger.warning( + f"Camera {self.device} resolution set to {actual_width}x{actual_height} " + f"instead of requested {self.resolution[0]}x{self.resolution[1]}" + ) + self.resolution = (actual_width, actual_height) + + if self.fps: + self._cap.set(cv2.CAP_PROP_FPS, self.fps) + + actual_fps = int(self._cap.get(cv2.CAP_PROP_FPS)) + if actual_fps != self.fps: + logger.warning(f"Camera {self.device} FPS set to {actual_fps} instead of requested {self.fps}") + self.fps = actual_fps + + logger.info(f"Opened V4L camera with index {self.device}") + + def _close_camera(self) -> None: + """Close the V4L camera connection.""" + if self._cap is not None: + self._cap.release() + self._cap = None + + def _read_frame(self) -> np.ndarray | None: + """Read a frame from the V4L camera.""" + if self._cap is None: + return None + + ret, frame = self._cap.read() + if not ret or frame is None: + raise CameraReadError(f"Failed to read from V4L camera {self.device}") + + return frame diff --git a/src/arduino/app_peripherals/camera/websocket_camera.py b/src/arduino/app_peripherals/camera/websocket_camera.py new file mode 100644 index 00000000..3b57ab99 --- /dev/null +++ b/src/arduino/app_peripherals/camera/websocket_camera.py @@ -0,0 +1,345 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA +# +# SPDX-License-Identifier: MPL-2.0 + +import json +import base64 +import threading +import queue +import time +import numpy as np +import cv2 +import websockets +import asyncio +from collections.abc import Callable +from concurrent.futures import CancelledError, TimeoutError + +from arduino.app_utils import Logger + +from .camera import BaseCamera +from .errors import CameraOpenError + +logger = Logger("WebSocketCamera") + + +class WebSocketCamera(BaseCamera): + """ + WebSocket Camera implementation that hosts a 
WebSocket server. + + This camera acts as a WebSocket server that receives frames from connected clients. + Only one client can be connected at a time. + + Clients must encode video frames in one of these formats: + - JPEG + - PNG + - WebP + - BMP + - TIFF + + The frames can be serialized in one of the following formats: + - Binary image data + - Base64 encoded images + - JSON messages with image data + """ + + def __init__( + self, + host: str = "0.0.0.0", + port: int = 8080, + timeout: int = 10, + frame_format: str = "binary", + resolution: tuple[int, int] = (640, 480), + fps: int = 10, + adjustments: Callable[[np.ndarray], np.ndarray] = None, + ): + """ + Initialize WebSocket camera server. + + Args: + host (str): Host address to bind the server to (default: "0.0.0.0") + port (int): Port to bind the server to (default: 8080) + timeout (int): Connection timeout in seconds (default: 10) + frame_format (str): Expected frame format from clients ("binary", "base64", "json") (default: "binary") + resolution (tuple, optional): Resolution as (width, height). None uses default resolution. + fps (int): Frames per second to capture from the camera. + adjustments (callable, optional): Function or function pipeline to adjust frames that takes + a numpy array and returns a numpy array. Default: None + """ + super().__init__(resolution, fps, adjustments) + + self.host = host + self.port = port + self.timeout = timeout + self.frame_format = frame_format + self.logger = logger + + self._frame_queue = queue.Queue(1) + self._server = None + self._loop = None + self._server_thread = None + self._stop_event = asyncio.Event() + self._client: websockets.ServerConnection = None + self._client_lock = asyncio.Lock() + + def _open_camera(self) -> None: + """Start the WebSocket server.""" + # Start server in separate thread with its own event loop + self._server_thread = threading.Thread(target=self._start_server_thread, daemon=True) + self._server_thread.start() + + # Wait for server to start + start_time = time.time() + start_timeout = 10 + while self._server is None and time.time() - start_time < start_timeout: + if self._server is not None: + break + time.sleep(0.1) + + if self._server is None: + raise CameraOpenError(f"Failed to start WebSocket server on {self.host}:{self.port}") + + def _start_server_thread(self) -> None: + """Run WebSocket server in its own thread with event loop.""" + try: + self._loop = asyncio.new_event_loop() + asyncio.set_event_loop(self._loop) + self._loop.run_until_complete(self._start_server()) + except Exception as e: + logger.error(f"WebSocket server thread error: {e}") + finally: + if self._loop and not self._loop.is_closed(): + self._loop.close() + + async def _start_server(self) -> None: + """Start the WebSocket server.""" + try: + self._stop_event.clear() + + self._server = await websockets.serve( + self._ws_handler, + self.host, + self.port, + open_timeout=self.timeout, + ping_timeout=self.timeout, + close_timeout=self.timeout, + ping_interval=20, + ) + + logger.info(f"WebSocket camera server started on {self.host}:{self.port}") + + await self._stop_event.wait() + + except Exception as e: + logger.error(f"Error starting WebSocket server: {e}") + raise + finally: + if self._server: + self._server.close() + await self._server.wait_closed() + + async def _ws_handler(self, conn: websockets.ServerConnection) -> None: + """Handle a connected WebSocket client. 
Only one client allowed at a time.""" + client_addr = f"{conn.remote_address[0]}:{conn.remote_address[1]}" + + async with self._client_lock: + if self._client is not None: + # Reject the new client + logger.warning(f"Rejecting client {client_addr}: only one client allowed at a time") + try: + await conn.send(json.dumps({"error": "Server busy", "message": "Only one client connection allowed at a time", "code": 1000})) + await conn.close(code=1000, reason="Server busy - only one client allowed") + except Exception as e: + logger.warning(f"Error sending rejection message to {client_addr}: {e}") + return + + # Accept the client + self._client = conn + + logger.info(f"Client connected: {client_addr}") + + try: + # Send welcome message + try: + await self._send_to_client({ + "status": "connected", + "message": "You are now connected to the camera server", + "frame_format": self.frame_format, + "resolution": self.resolution, + "fps": self.fps, + }) + except Exception as e: + logger.warning(f"Could not send welcome message to {client_addr}: {e}") + + async for message in conn: + frame = await self._parse_message(message) + if frame is not None: + # Drop old frames until there's room for the new one + while True: + try: + self._frame_queue.put_nowait(frame) + break + except queue.Full: + try: + # Drop oldest frame and try again + self._frame_queue.get_nowait() + except queue.Empty: + continue + + except websockets.exceptions.ConnectionClosed: + logger.info(f"Client disconnected: {client_addr}") + except Exception as e: + logger.warning(f"Error handling client {client_addr}: {e}") + finally: + async with self._client_lock: + if self._client == conn: + self._client = None + logger.info(f"Client removed: {client_addr}") + + async def _parse_message(self, message) -> np.ndarray | None: + """Parse WebSocket message to extract frame.""" + try: + if self.frame_format == "base64": + # Expect base64 encoded image + if isinstance(message, str): + image_data = base64.b64decode(message) + else: + image_data = base64.b64decode(message.decode()) + + # Decode image + nparr = np.frombuffer(image_data, np.uint8) + frame = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED) + return frame + + elif self.frame_format == "binary": + # Expect raw binary image data + if isinstance(message, str): + image_data = message.encode() + else: + image_data = message + + nparr = np.frombuffer(image_data, np.uint8) + frame = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED) + return frame + + elif self.frame_format == "json": + # Expect JSON with image data + if isinstance(message, bytes): + message = message.decode() + + data = json.loads(message) + + if "image" in data: + image_data = base64.b64decode(data["image"]) + nparr = np.frombuffer(image_data, np.uint8) + frame = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED) + return frame + + elif "frame" in data: + # Handle different frame data formats + frame_data = data["frame"] + if isinstance(frame_data, str): + image_data = base64.b64decode(frame_data) + nparr = np.frombuffer(image_data, np.uint8) + frame = cv2.imdecode(nparr, cv2.IMREAD_UNCHANGED) + return frame + + return None + + except Exception as e: + logger.warning(f"Error parsing message: {e}") + return None + + def _close_camera(self): + """Stop the WebSocket server.""" + # Signal async stop event if it exists + if self._loop and not self._loop.is_closed(): + future = asyncio.run_coroutine_threadsafe(self._set_async_stop_event(), self._loop) + try: + future.result(timeout=1.0) + except CancelledError: + logger.debug(f"Error setting async stop 
event: CancelledError") + except TimeoutError: + logger.debug(f"Error setting async stop event: TimeoutError") + except Exception as e: + logger.warning(f"Error setting async stop event: {e}") + + # Wait for server thread to finish + if self._server_thread and self._server_thread.is_alive(): + self._server_thread.join(timeout=10.0) + + # Clear frame queue + try: + while True: + self._frame_queue.get_nowait() + except queue.Empty: + pass + + # Reset state + self._server = None + self._loop = None + self._client = None + + async def _set_async_stop_event(self): + """Set the async stop event and close the client connection.""" + # Send goodbye message and close the client connection + if self._client: + try: + # Send goodbye message before closing + await self._send_to_client({ + "status": "disconnecting", + "message": "Server is shutting down. Connection will be closed.", + }) + # Give a brief moment for the message to be sent + await asyncio.sleep(0.1) + except Exception as e: + logger.warning(f"Error closing client in stop event: {e}") + finally: + await self._client.close() + self._stop_event.set() + + def _read_frame(self) -> np.ndarray | None: + """Read a frame from the queue.""" + try: + # Get frame with short timeout to avoid blocking + frame = self._frame_queue.get(timeout=0.1) + return frame + except queue.Empty: + return None + + def _send_message_to_client(self, message: str | bytes | dict) -> None: + """ + Send a message to the connected client (if any). + + Args: + message: Message to send to the client + + Raises: + RuntimeError: If the event loop is not running or closed + ConnectionError: If no client is connected + Exception: For other communication errors + """ + if not self._loop or self._loop.is_closed(): + raise RuntimeError("WebSocket server event loop is not running") + + if self._client is None: + raise ConnectionError("No client connected to send message to") + + # Schedule message sending in the server's event loop + future = asyncio.run_coroutine_threadsafe(self._send_to_client(message), self._loop) + + try: + future.result(timeout=5.0) + except Exception as e: + logger.error(f"Error sending message to client: {e}") + raise + + async def _send_to_client(self, message: str | bytes | dict) -> None: + """Send message to a single client.""" + if isinstance(message, dict): + message = json.dumps(message) + + try: + await self._client.send(message) + except Exception as e: + logger.warning(f"Error sending to client: {e}") + raise diff --git a/src/arduino/app_peripherals/usb_camera/README.md b/src/arduino/app_peripherals/usb_camera/README.md index 502517ae..88d6cd83 100644 --- a/src/arduino/app_peripherals/usb_camera/README.md +++ b/src/arduino/app_peripherals/usb_camera/README.md @@ -1,5 +1,8 @@ # USB Camera +> [!NOTE] +> This peripheral is deprecated, use the Camera peripheral instead. + The `USBCamera` peripheral captures images and videos from a connected USB camera. 
## Features diff --git a/src/arduino/app_peripherals/usb_camera/__init__.py b/src/arduino/app_peripherals/usb_camera/__init__.py index e71a9fc3..421cce2b 100644 --- a/src/arduino/app_peripherals/usb_camera/__init__.py +++ b/src/arduino/app_peripherals/usb_camera/__init__.py @@ -2,30 +2,22 @@ # # SPDX-License-Identifier: MPL-2.0 -import threading -import time -import cv2 import io -import os -import re +import warnings from PIL import Image +from arduino.app_peripherals.camera import Camera as Camera, CameraReadError as CRE, CameraOpenError as COE +from arduino.app_peripherals.camera.v4l_camera import V4LCamera +from arduino.app_utils.image import letterboxed, compressed_to_png from arduino.app_utils import Logger logger = Logger("USB Camera") +CameraReadError = CRE -class CameraReadError(Exception): - """Exception raised when the specified camera cannot be found.""" - - pass - - -class CameraOpenError(Exception): - """Exception raised when the camera cannot be opened.""" - - pass +CameraOpenError = COE +@warnings.deprecated("Use the Camera peripheral instead of this one") class USBCamera: """Represents an input peripheral for capturing images from a USB camera device. This class uses OpenCV to interface with the camera and capture images. @@ -34,7 +26,7 @@ class USBCamera: def __init__( self, camera: int = 0, - resolution: tuple[int, int] = (None, None), + resolution: tuple[int, int] = None, fps: int = 10, compression: bool = False, letterbox: bool = False, @@ -48,27 +40,15 @@ def __init__( compression (bool): Whether to compress the captured images. If True, images are compressed to PNG format. letterbox (bool): Whether to apply letterboxing to the captured images. """ - video_devices = self._get_video_devices_by_index() - if camera in video_devices: - self.camera = int(video_devices[camera]) - else: - raise CameraOpenError( - f"Not available camera at index 0 {camera}. Verify the connected cameras and fi cameras are listed " - f"inside devices listed here: /dev/v4l/by-id" - ) - - self.resolution = resolution - self.fps = fps self.compression = compression - self.letterbox = letterbox - self._cap = None - self._cap_lock = threading.Lock() - self._last_capture_time_monotonic = time.monotonic() - if self.fps > 0: - self.desired_interval = 1.0 / self.fps - else: - # Capture as fast as possible - self.desired_interval = 0 + + # Letterbox first, then compress: an encoded PNG byte stream cannot be letterboxed + pipe = None + if letterbox: + pipe = letterboxed() + if compression: + pipe = pipe | compressed_to_png() if pipe else compressed_to_png() + + self._wrapped_camera = V4LCamera(camera, resolution, fps, pipe) def capture(self) -> Image.Image | None: """Captures a frame from the camera, blocking to respect the configured FPS. @@ -76,7 +56,7 @@ Returns: PIL.Image.Image | None: The captured frame as a PIL Image, or None if no frame is available. """ - image_bytes = self._extract_frame() + image_bytes = self._wrapped_camera.capture() if image_bytes is None: return None try: @@ -95,157 +75,18 @@ def capture_bytes(self) -> bytes | None: Returns: bytes | None: The captured frame as a bytes array, or None if no frame is available. """ - frame = self._extract_frame() + frame = self._wrapped_camera.capture() if frame is None: return None return frame.tobytes() - def _extract_frame(self) -> cv2.typing.MatLike | None: - # Without locking, 'elapsed_time' could be a stale value but this scenario is unlikely to be noticeable in - practice, also its effects would disappear in the next capture.
This optimization prevents us from calling - # time.sleep while holding a lock. - current_time_monotonic = time.monotonic() - elapsed_time = current_time_monotonic - self._last_capture_time_monotonic - if elapsed_time < self.desired_interval: - sleep_duration = self.desired_interval - elapsed_time - time.sleep(sleep_duration) # Keep time.sleep out of the locked section! - - with self._cap_lock: - if self._cap is None: - return None - - ret, bgr_frame = self._cap.read() - if not ret: - raise CameraReadError(f"Failed to read from camera {self.camera}.") - self._last_capture_time_monotonic = time.monotonic() - if bgr_frame is None: - # No frame available, skip this iteration - return None - - try: - if self.letterbox: - bgr_frame = self._letterbox(bgr_frame) - if self.compression: - success, rgb_frame = cv2.imencode(".png", bgr_frame) - if success: - return rgb_frame - else: - return None - else: - return cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB) - except cv2.error as e: - logger.exception(f"Error converting frame: {e}") - return None - - def _letterbox(self, frame: cv2.typing.MatLike) -> cv2.typing.MatLike: - """Applies letterboxing to the frame to make it square. - - Args: - frame (cv2.typing.MatLike): The input frame to be letterboxed (as cv2 supported format - numpy like). - - Returns: - cv2.typing.MatLike: The letterboxed frame (as cv2 supported format - numpy like). - """ - h, w = frame.shape[:2] - if w != h: - # Letterbox: add padding to make it square (yolo colors) - size = max(h, w) - return cv2.copyMakeBorder( - frame, - top=(size - h) // 2, - bottom=(size - h + 1) // 2, - left=(size - w) // 2, - right=(size - w + 1) // 2, - borderType=cv2.BORDER_CONSTANT, - value=(114, 114, 114), - ) - else: - return frame - - def _get_video_devices_by_index(self): - """Reads symbolic links in /dev/v4l/by-id/, resolves them, and returns a - dictionary mapping the numeric index to the system /dev/videoX device. - - Returns: - dict[int, str]: a dict where keys are ordinal integer indices (e.g., 0, 1) and values are the - /dev/videoX device names (e.g., "0", "1"). - """ - devices_by_index = {} - directory_path = "/dev/v4l/by-id/" - - # Check if the directory exists - if not os.path.exists(directory_path): - logger.error(f"Error: Directory '{directory_path}' not found.") - return devices_by_index - - try: - # List all entries in the directory - entries = os.listdir(directory_path) - - for entry in entries: - full_path = os.path.join(directory_path, entry) - - # Check if the entry is a symbolic link - if os.path.islink(full_path): - # Use a regular expression to find the numeric index at the end of the filename - match = re.search(r"index(\d+)$", entry) - if match: - index_str = match.group(1) - try: - index = int(index_str) - - # Resolve the symbolic link to its absolute path - resolved_path = os.path.realpath(full_path) - - # Get just the filename (e.g., "video0") from the resolved path - device_name = os.path.basename(resolved_path) - - # Remove the "video" prefix to get just the number - device_number = device_name.replace("video", "") - - # Add the index and device number to the dictionary - devices_by_index[index] = device_number - - except ValueError: - logger.warning(f"Warning: Could not convert index '{index_str}' to an integer for '{entry}'. 
Skipping.") - continue - except OSError as e: - logger.error(f"Error accessing directory '{directory_path}': {e}") - return devices_by_index - - return devices_by_index - def start(self): """Starts the camera capture.""" - with self._cap_lock: - if self._cap is not None: - return - - temp_cap = cv2.VideoCapture(self.camera) - if not temp_cap.isOpened(): - raise CameraOpenError(f"Failed to open camera {self.camera}.") - - self._cap = temp_cap # Assign only after successful initialization - self._last_capture_time_monotonic = time.monotonic() - - if self.resolution[0] is not None and self.resolution[1] is not None: - self._cap.set(cv2.CAP_PROP_FRAME_WIDTH, self.resolution[0]) - self._cap.set(cv2.CAP_PROP_FRAME_HEIGHT, self.resolution[1]) - # Verify if setting resolution was successful - actual_width = self._cap.get(cv2.CAP_PROP_FRAME_WIDTH) - actual_height = self._cap.get(cv2.CAP_PROP_FRAME_HEIGHT) - if actual_width != self.resolution[0] or actual_height != self.resolution[1]: - logger.warning( - f"Camera {self.camera} could not be set to {self.resolution[0]}x{self.resolution[1]}, " - f"actual resolution: {int(actual_width)}x{int(actual_height)}", - ) + self._wrapped_camera.start() def stop(self): """Stops the camera and releases its resources.""" - with self._cap_lock: - if self._cap is not None: - self._cap.release() - self._cap = None + self._wrapped_camera.stop() def produce(self): """Alias for capture method.""" diff --git a/src/arduino/app_utils/__init__.py b/src/arduino/app_utils/__init__.py index df9e4353..87d90aae 100644 --- a/src/arduino/app_utils/__init__.py +++ b/src/arduino/app_utils/__init__.py @@ -8,11 +8,9 @@ from .bridge import * from .folderwatch import * from .httprequest import * -from .image import * from .jsonparser import * from .logger import * from .slidingwindowbuffer import * -from .userinput import * __all__ = [ "App", @@ -23,12 +21,8 @@ "provide", "FolderWatcher", "HttpClient", - "draw_bounding_boxes", - "get_image_bytes", - "get_image_type", "JSONParser", "Logger", "SineGenerator", "SlidingWindowBuffer", - "UserTextInput", ] diff --git a/src/arduino/app_utils/image/__init__.py b/src/arduino/app_utils/image/__init__.py new file mode 100644 index 00000000..e9fc4196 --- /dev/null +++ b/src/arduino/app_utils/image/__init__.py @@ -0,0 +1,27 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA +# +# SPDX-License-Identifier: MPL-2.0 + +from .image import * +from .adjustments import * +from .pipeable import PipeableFunction + +__all__ = [ + "get_image_type", + "get_image_bytes", + "draw_bounding_boxes", + "draw_anomaly_markers", + "letterbox", + "resize", + "adjust", + "greyscale", + "compress_to_jpeg", + "compress_to_png", + "letterboxed", + "resized", + "adjusted", + "greyscaled", + "compressed_to_jpeg", + "compressed_to_png", + "PipeableFunction", +] diff --git a/src/arduino/app_utils/image/adjustments.py b/src/arduino/app_utils/image/adjustments.py new file mode 100644 index 00000000..97a63392 --- /dev/null +++ b/src/arduino/app_utils/image/adjustments.py @@ -0,0 +1,438 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA +# +# SPDX-License-Identifier: MPL-2.0 + +import cv2 +import numpy as np +from typing import Optional, Tuple +from PIL import Image + +from arduino.app_utils.image.pipeable import PipeableFunction + +# NOTE: we use the following formats for image shapes (H = height, W = width, C = channels): +# - When receiving a resolution as argument we expect (W, H) format which is more user-friendly +# - When receiving images we expect (H, W, C) 
format with C = BGR, BGRA or greyscale +# - When returning images we use (H, W, C) format with C = BGR, BGRA or greyscale (depending on input) +# Keep in mind OpenCV stores images as numpy arrays in (H, W, C) format with C = BGR, but takes size +# arguments (e.g. in cv2.resize) in (W, H) order; plain numpy imposes no meaning on C. +# The below functions all support unsigned integer types used by OpenCV (uint8, uint16 and uint32). + + +""" +Image processing utilities handling common image operations like letterboxing, resizing, +adjusting, compressing and format conversions. +Frames are expected to be in BGR, BGRA or greyscale format. +""" + + +def letterbox( + frame: np.ndarray, + target_size: Optional[Tuple[int, int]] = None, + color: int | Tuple[int, int, int] = (114, 114, 114), + interpolation: int = cv2.INTER_LINEAR, +) -> np.ndarray: + """ + Add letterboxing to frame to achieve target size while maintaining aspect ratio. + + Args: + frame (np.ndarray): Input frame + target_size (tuple, optional): Target size as (width, height). If None, makes frame square. + color (int or tuple, optional): BGR color for padding borders, can be a scalar or a tuple + matching the frame's channel count. Default: (114, 114, 114) + interpolation (int, optional): OpenCV interpolation method. Default: cv2.INTER_LINEAR + + Returns: + np.ndarray: Letterboxed frame + """ + original_dtype = frame.dtype + orig_h, orig_w = frame.shape[:2] + + if target_size is None: + # Default to a square canvas based on the longest side + max_dim = max(orig_h, orig_w) + target_w, target_h = int(max_dim), int(max_dim) + else: + target_w, target_h = int(target_size[0]), int(target_size[1]) + + scale = min(target_w / orig_w, target_h / orig_h) + new_w = int(orig_w * scale) + new_h = int(orig_h * scale) + + if new_w == orig_w and new_h == orig_h: + resized_frame = frame + else: + resized_frame = cv2.resize(frame, (new_w, new_h), interpolation=interpolation) + + if frame.ndim == 2: + # Greyscale + if hasattr(color, "__len__"): + color = color[0] + canvas = np.full((target_h, target_w), color, dtype=original_dtype) + else: + # Colored (BGR/BGRA) + channels = frame.shape[2] + if not hasattr(color, "__len__"): + color = (color,) * channels + elif len(color) != channels: + raise ValueError(f"color length ({len(color)}) must match frame channels ({channels}).") + canvas = np.full((target_h, target_w, channels), color, dtype=original_dtype) + + # Calculate offsets to center the image + y_offset = (target_h - new_h) // 2 + x_offset = (target_w - new_w) // 2 + + # Paste the resized image onto the canvas + canvas[y_offset : y_offset + new_h, x_offset : x_offset + new_w] = resized_frame + + return canvas + + +def resize(frame: np.ndarray, target_size: Tuple[int, int], maintain_ratio: bool = False, interpolation: int = cv2.INTER_LINEAR) -> np.ndarray: + """ + Resize frame to target size. + + Args: + frame (np.ndarray): Input frame + target_size (tuple): Target size as (width, height) + maintain_ratio (bool): If True, use letterboxing to maintain aspect ratio. Default: False. + interpolation (int): OpenCV interpolation method. Default: cv2.INTER_LINEAR.
+ + Returns: + np.ndarray: Resized frame + """ + if frame.shape[1] == target_size[0] and frame.shape[0] == target_size[1]: + return frame + + if maintain_ratio: + return letterbox(frame, target_size) + else: + return cv2.resize(frame, (target_size[0], target_size[1]), interpolation=interpolation) + + +def adjust(frame: np.ndarray, brightness: float = 0.0, contrast: float = 1.0, saturation: float = 1.0, gamma: float = 1.0) -> np.ndarray: + """ + Apply image adjustments to a BGR or BGRA frame, preserving channel count + and data type. + + Args: + frame (np.ndarray): Input frame (uint8, uint16, uint32). + brightness (float): -1.0 to 1.0 (default: 0.0). + contrast (float): 0.0 to N (default: 1.0). + saturation (float): 0.0 to N (default: 1.0). + gamma (float): > 0 (default: 1.0). + + Returns: + np.ndarray: The adjusted input with same dtype as frame. + """ + original_dtype = frame.dtype + dtype_info = np.iinfo(original_dtype) + max_val = dtype_info.max + + # Use float64 for int types with > 24 bits of precision (e.g., uint32) + processing_dtype = np.float64 if dtype_info.bits > 24 else np.float32 + + # Apply the adjustments in float space to reduce clipping and data loss + frame_float = frame.astype(processing_dtype) / max_val + + # If present, separate alpha channel + alpha_channel = None + if frame.ndim == 3 and frame.shape[2] == 4: + alpha_channel = frame_float[:, :, 3] + frame_float = frame_float[:, :, :3] + + # Saturation + if saturation != 1.0 and frame.ndim == 3: # Ensure frame has color channels + # This must be done with float32 so it's lossy only for uint32 + frame_float_32 = frame_float.astype(np.float32) + hsv = cv2.cvtColor(frame_float_32, cv2.COLOR_BGR2HSV) + h, s, v = split_channels(hsv) + s = np.clip(s * saturation, 0.0, 1.0) + frame_float_32 = cv2.cvtColor(np.stack([h, s, v], axis=2), cv2.COLOR_HSV2BGR) + frame_float = frame_float_32.astype(processing_dtype) + + # Brightness + if brightness != 0.0: + frame_float = frame_float + brightness + + # Contrast + if contrast != 1.0: + frame_float = (frame_float - 0.5) * contrast + 0.5 + + # We need to clip before reaching gamma correction + # Clipping to 0 is mandatory to avoid handling complex numbers + # Clipping to 1 is handy to avoid clipping again after gamma correction + frame_float = np.clip(frame_float, 0.0, 1.0) + + # Gamma + if gamma != 1.0: + if gamma <= 0: + # This check is critical to prevent math errors (NaN/Inf) + raise ValueError("Gamma value must be greater than 0.") + frame_float = np.power(frame_float, gamma) + + # Convert back to original dtype + final_frame_bgr = (frame_float * max_val).astype(original_dtype) + + # If present, reattach alpha channel + if alpha_channel is not None: + final_alpha = (alpha_channel * max_val).astype(original_dtype) + b, g, r = split_channels(final_frame_bgr) + final_frame = np.stack([b, g, r, final_alpha], axis=2) + else: + final_frame = final_frame_bgr + + return final_frame + + +def split_channels(frame: np.ndarray) -> tuple: + """ + Split a multi-channel frame into individual channels using numpy indexing. + This function provides better data type compatibility than cv2.split, + especially for uint32 data which OpenCV doesn't fully support. + + Args: + frame (np.ndarray): Input frame with 3 or 4 channels + + Returns: + tuple: Individual channel arrays. For BGR: (b, g, r). For BGRA: (b, g, r, a). + For HSV: (h, s, v). For other 3-channel: (ch0, ch1, ch2). 
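+
+    Example:
+        b, g, r = split_channels(bgr_frame)  # bgr_frame: np.ndarray of shape (H, W, 3)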
+ """ + if frame.ndim != 3: + raise ValueError("Frame must be 3-dimensional (H, W, C)") + + channels = frame.shape[2] + if channels == 3: + return frame[:, :, 0], frame[:, :, 1], frame[:, :, 2] + elif channels == 4: + return frame[:, :, 0], frame[:, :, 1], frame[:, :, 2], frame[:, :, 3] + else: + raise ValueError(f"Unsupported number of channels: {channels}. Expected 3 or 4.") + + +def greyscale(frame: np.ndarray) -> np.ndarray: + """ + Converts a BGR or BGRA frame to greyscale, preserving channel count and + data type. A greyscale frame is returned unmodified. + + Args: + frame (np.ndarray): Input frame (uint8, uint16, uint32). + + Returns: + np.ndarray: The greyscaled frame with same dtype and channel count as frame. + """ + # If already greyscale or unknown format, return the original frame + if frame.ndim != 3: + return frame + + original_dtype = frame.dtype + dtype_info = np.iinfo(original_dtype) + max_val = dtype_info.max + + # Use float64 for int types with > 24 bits of precision (e.g., uint32) + processing_dtype = np.float64 if dtype_info.bits > 24 else np.float32 + + # Apply the adjustments in float space to reduce clipping and data loss + frame_float = frame.astype(processing_dtype) / max_val + + # If present, separate alpha channel + alpha_channel = None + if frame.shape[2] == 4: + alpha_channel = frame_float[:, :, 3] + frame_float = frame_float[:, :, :3] + + # Convert to greyscale using standard BT.709 weights + # GREY = 0.0722 * B + 0.7152 * G + 0.2126 * R + grey_float = 0.0722 * frame_float[:, :, 0] + 0.7152 * frame_float[:, :, 1] + 0.2126 * frame_float[:, :, 2] + + # Convert back to original dtype + final_grey = (grey_float * max_val).astype(original_dtype) + + # If present, reattach alpha channel + if alpha_channel is not None: + final_alpha = (alpha_channel * max_val).astype(original_dtype) + final_frame = np.stack([final_grey, final_grey, final_grey, final_alpha], axis=2) + else: + final_frame = np.stack([final_grey, final_grey, final_grey], axis=2) + + return final_frame + + +def compress_to_jpeg(frame: np.ndarray, quality: int = 80) -> Optional[np.ndarray]: + """ + Compress frame to JPEG format. + + Args: + frame (np.ndarray): Input frame as numpy array + quality (int): JPEG quality (0-100, higher = better quality) + + Returns: + bytes: Compressed JPEG data, or None if compression failed + """ + quality = int(quality) # Gstreamer doesn't like quality to be float + try: + success, encoded = cv2.imencode(".jpg", frame, [cv2.IMWRITE_JPEG_QUALITY, quality]) + return encoded if success else None + except Exception: + return None + + +def compress_to_png(frame: np.ndarray, compression_level: int = 6) -> Optional[np.ndarray]: + """ + Compress frame to PNG format. + + Args: + frame (np.ndarray): Input frame as numpy array + compression_level (int): PNG compression level (0-9, higher = better compression) + + Returns: + bytes: Compressed PNG data, or None if compression failed + """ + compression_level = int(compression_level) # Gstreamer doesn't like compression_level to be float + try: + success, encoded = cv2.imencode(".png", frame, [cv2.IMWRITE_PNG_COMPRESSION, compression_level]) + return encoded if success else None + except Exception: + return None + + +def numpy_to_pil(frame: np.ndarray) -> Image.Image: + """ + Convert numpy array to PIL Image. 
+ + Args: + frame (np.ndarray): Input frame in BGR format + + Returns: + PIL.Image.Image: PIL Image in RGB format + """ + # Convert BGR to RGB + rgb_frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + return Image.fromarray(rgb_frame) + + +def pil_to_numpy(image: Image.Image) -> np.ndarray: + """ + Convert PIL Image to numpy array. + + Args: + image (PIL.Image.Image): PIL Image + + Returns: + np.ndarray: Numpy array in BGR format + """ + if image.mode != "RGB": + image = image.convert("RGB") + + # Convert to numpy and then BGR + rgb_array = np.array(image) + return cv2.cvtColor(rgb_array, cv2.COLOR_RGB2BGR) + + +# ============================================================================= +# Functional API - Standalone pipeable functions +# ============================================================================= + + +def letterboxed(target_size: Optional[Tuple[int, int]] = None, color: Tuple[int, int, int] = (114, 114, 114), interpolation: int = cv2.INTER_LINEAR): + """ + Pipeable letterbox function - apply letterboxing with pipe operator support. + + Args: + target_size (tuple, optional): Target size as (width, height). If None, makes frame square. + color (tuple): BGR color for padding borders. Default: (114, 114, 114) + interpolation (int): OpenCV interpolation method. Default: cv2.INTER_LINEAR + + Returns: + Partial function that takes a frame and returns letterboxed frame + + Examples: + pipe = letterboxed(target_size=(640, 640)) + pipe = letterboxed() | greyscaled() + """ + return PipeableFunction(letterbox, target_size=target_size, color=color, interpolation=interpolation) + + +def resized(target_size: Tuple[int, int], maintain_ratio: bool = False, interpolation: int = cv2.INTER_LINEAR): + """ + Pipeable resize function - resize frame with pipe operator support. + + Args: + target_size (tuple): Target size as (width, height) + maintain_ratio (bool): If True, use letterboxing to maintain aspect ratio + interpolation (int): OpenCV interpolation method. Default: cv2.INTER_LINEAR + + Returns: + Partial function that takes a frame and returns resized frame + + Examples: + pipe = resized(target_size=(640, 480)) + pipe = letterboxed() | resized(target_size=(320, 240)) + """ + return PipeableFunction(resize, target_size=target_size, maintain_ratio=maintain_ratio, interpolation=interpolation) + + +def adjusted(brightness: float = 0.0, contrast: float = 1.0, saturation: float = 1.0, gamma: float = 1.0): + """ + Pipeable adjust function - apply image adjustments with pipe operator support. + + Args: + brightness (float): -1.0 to 1.0 (default: 0.0). + contrast (float): 0.0 to N (default: 1.0). + saturation (float): 0.0 to N (default: 1.0). + gamma (float): > 0 (default: 1.0). + + Returns: + Partial function that takes a frame and returns adjusted frame + + Examples: + pipe = adjusted(brightness=0.1, contrast=1.2) + pipe = letterboxed() | adjusted(saturation=0.8) + """ + return PipeableFunction(adjust, brightness=brightness, contrast=contrast, saturation=saturation, gamma=gamma) + + +def greyscaled(): + """ + Pipeable greyscale function - convert frame to greyscale with pipe operator support. + + Returns: + Function that takes a frame and returns greyscale frame + + Examples: + pipe = greyscaled() + pipe = letterboxed() | greyscaled() + """ + return PipeableFunction(greyscale) + + +def compressed_to_jpeg(quality: int = 80): + """ + Pipeable JPEG compression function - compress frame to JPEG with pipe operator support.
+ + Args: + quality (int): JPEG quality (0-100, higher = better quality) + + Returns: + Partial function that takes a frame and returns compressed JPEG bytes as Numpy array or None + + Examples: + pipe = compressed_to_jpeg(quality=95) + pipe = resized(target_size=(640, 480)) | compressed_to_jpeg() + """ + return PipeableFunction(compress_to_jpeg, quality=quality) + + +def compressed_to_png(compression_level: int = 6): + """ + Pipeable PNG compression function - compress frame to PNG with pipe operator support. + + Args: + compression_level (int): PNG compression level (0-9, higher = better compression) + + Returns: + Partial function that takes a frame and returns compressed PNG bytes as Numpy array or None + + Examples: + pipe = compressed_to_png(compression_level=9) + pipe = letterboxed() | compressed_to_png() + """ + return PipeableFunction(compress_to_png, compression_level=compression_level) diff --git a/src/arduino/app_utils/image.py b/src/arduino/app_utils/image/image.py similarity index 87% rename from src/arduino/app_utils/image.py rename to src/arduino/app_utils/image/image.py index 8870f9b1..e24aa397 100644 --- a/src/arduino/app_utils/image.py +++ b/src/arduino/app_utils/image/image.py @@ -35,7 +35,7 @@ def _read(file_path: str) -> bytes: with open(file_path, "rb") as f: return f.read() except Exception as e: - logger(f"Error reading image: {e}") + logger.error(f"Error reading image: {e}") return None @@ -78,22 +78,6 @@ def get_image_bytes(image: str | Image.Image | bytes) -> bytes: return None -def draw_colored_dot(draw, x, y, color, size): - """Draws a large colored dot on a PIL Image at the specified coordinate. - - Args: - draw: An ImageDraw object from PIL. - x: The x-coordinate of the center of the dot. - y: The y-coordinate of the center of the dot. - color: A color value that PIL understands (e.g., "red", (255, 0, 0), "#FF0000"). - size: The radius of the dot (in pixels). - """ - # Calculate the bounding box for the circle - bounding_box = (x - size, y - size, x + size, y + size) - # Draw a filled ellipse (which looks like a circle if the bounding box is a square) - draw.ellipse(bounding_box, fill=color) - - def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: ImageDraw.ImageDraw = None) -> Image.Image | None: """Draw bounding boxes on an image using PIL. @@ -178,11 +162,7 @@ def draw_bounding_boxes(image: Image.Image | bytes, detection: dict, draw: Image return image_box -def draw_anomaly_markers( - image: Image.Image | bytes, - detection: dict, - draw: ImageDraw.ImageDraw = None, -) -> Image.Image | None: +def draw_anomaly_markers(image: Image.Image | bytes, detection: dict, draw: ImageDraw.ImageDraw = None) -> Image.Image | None: """Draw bounding boxes on an image using PIL. The thickness of the box and font size are scaled based on image size. @@ -192,9 +172,6 @@ def draw_anomaly_markers( detection (dict): A dictionary containing detection results with keys 'class_name', 'bounding_box_xyxy', and 'score'. draw (ImageDraw.ImageDraw, optional): An existing ImageDraw object to use. If None, a new one is created. - label_above_box (bool, optional): If True, labels are drawn above the bounding box. Defaults to False. - colours (list, optional): List of colors to use for bounding boxes. Defaults to a predefined palette. - text_color (str, optional): Color of the text labels. Defaults to "white". 
""" if isinstance(image, bytes): image_box = Image.open(io.BytesIO(image)) diff --git a/src/arduino/app_utils/image/pipeable.py b/src/arduino/app_utils/image/pipeable.py new file mode 100644 index 00000000..86e0bad9 --- /dev/null +++ b/src/arduino/app_utils/image/pipeable.py @@ -0,0 +1,99 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA +# +# SPDX-License-Identifier: MPL-2.0 + +""" +Decorator for adding pipe operator support to transformation functions. + +This module provides a decorator that wraps static functions to support +the | (pipe) operator for functional composition. + +Note: Due to numpy's element-wise operator behavior, using the pipe operator +with numpy arrays (array | function) is not supported. Use function(array) instead. +""" + +from typing import Callable + + +class PipeableFunction: + """ + Wrapper class that adds pipe operator support to a function. + + This allows functions to be composed using the | operator in a left-to-right manner. + """ + + def __init__(self, func: Callable, *args, **kwargs): + """ + Initialize a pipeable function. + + Args: + func: The function to wrap + *args: Positional arguments to partially apply + **kwargs: Keyword arguments to partially apply + """ + self.func = func + self.args = args + self.kwargs = kwargs + + def __call__(self, *args, **kwargs): + """Call the wrapped function with combined arguments.""" + combined_args = self.args + args + combined_kwargs = {**self.kwargs, **kwargs} + return self.func(*combined_args, **combined_kwargs) + + def __ror__(self, other): + """ + Right-hand side of pipe operator (|). + + This allows: value | pipeable_function + + Args: + other: The value being piped into this function + + Returns: + Result of applying this function to the value + """ + return self(other) + + def __or__(self, other): + """ + Left-hand side of pipe operator (|). 
+ + This allows: pipeable_function | other_function + + Args: + other: Another function to compose with + + Returns: + A new pipeable function that combines both + """ + if not callable(other): + # Raise TypeError immediately instead of returning NotImplemented + # This prevents Python from trying the reverse operation for nothing + raise TypeError(f"unsupported operand type(s) for |: '{type(self).__name__}' and '{type(other).__name__}'") + + def composed(value): + return other(self(value)) + + return PipeableFunction(composed) + + def __repr__(self): + """String representation of the pipeable function.""" + # Get function name safely + func_name = getattr(self.func, "__name__", None) + if func_name is None: + func_name = getattr(type(self.func), "__name__", None) + if func_name is None: + from functools import partial + + if type(self.func) is partial: + func_name = "partial" + if func_name is None: + func_name = "unknown" # Fallback + + if self.args or self.kwargs: + args_str = ", ".join(map(str, self.args)) + kwargs_str = ", ".join(f"{k}={v}" for k, v in self.kwargs.items()) + all_args = ", ".join(filter(None, [args_str, kwargs_str])) + return f"{func_name}({all_args})" + return f"{func_name}()" diff --git a/src/arduino/app_utils/userinput.py b/src/arduino/app_utils/userinput.py deleted file mode 100644 index 530b978e..00000000 --- a/src/arduino/app_utils/userinput.py +++ /dev/null @@ -1,14 +0,0 @@ -# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA -# -# SPDX-License-Identifier: MPL-2.0 - - -class UserTextInput: - def __init__(self, prompt: str): - self.prompt = prompt - - def get(self): - return input(self.prompt) - - def produce(self): - return input(self.prompt) diff --git a/tests/arduino/app_bricks/imageclassification/test_imageclassification.py b/tests/arduino/app_bricks/imageclassification/test_imageclassification.py index 6f748561..19b6fa0a 100644 --- a/tests/arduino/app_bricks/imageclassification/test_imageclassification.py +++ b/tests/arduino/app_bricks/imageclassification/test_imageclassification.py @@ -16,7 +16,7 @@ def mock_dependencies(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr("arduino.app_internal.core.parse_docker_compose_variable", lambda x: [(None, None), (None, "8200")]) # make get_image_bytes a no-op for raw bytes monkeypatch.setattr( - "arduino.app_utils.get_image_bytes", + "arduino.app_utils.image.get_image_bytes", lambda x: x if isinstance(x, (bytes, bytearray)) else None, ) diff --git a/tests/arduino/app_bricks/objectdetection/test_objectdetection.py b/tests/arduino/app_bricks/objectdetection/test_objectdetection.py index f98f779f..8d6b714a 100644 --- a/tests/arduino/app_bricks/objectdetection/test_objectdetection.py +++ b/tests/arduino/app_bricks/objectdetection/test_objectdetection.py @@ -4,8 +4,6 @@ import pytest from pathlib import Path -import io -from PIL import Image from arduino.app_bricks.object_detection import ObjectDetection @@ -113,27 +111,6 @@ def fake_post( assert result["detection"] == [{"class_name": "C", "confidence": "50.00", "bounding_box_xyxy": [1.0, 2.0, 4.0, 6.0]}] -def test_draw_bounding_boxes(detector: ObjectDetection): - """Test the draw_bounding_boxes method with a valid image and detection. - - This test checks if the method returns a PIL Image object. - - Args: - detector (ObjectDetection): An instance of the ObjectDetection class. 
- """ - img = Image.new("RGB", (20, 20), color="white") - det = {"detection": [{"class_name": "X", "bounding_box_xyxy": [2, 2, 10, 10], "confidence": "50.0"}]} - - out = detector.draw_bounding_boxes(img, det) - assert isinstance(out, Image.Image) - - buf = io.BytesIO() - img.save(buf, format="PNG") - raw = buf.getvalue() - out2 = detector.draw_bounding_boxes(raw, det) - assert isinstance(out2, Image.Image) - - def test_process(monkeypatch: pytest.MonkeyPatch, tmp_path: Path, detector: ObjectDetection): """Test the process method with a valid file path. diff --git a/tests/arduino/app_core/test_edge_impulse.py b/tests/arduino/app_core/test_edge_impulse.py index bf3d8985..0e53f110 100644 --- a/tests/arduino/app_core/test_edge_impulse.py +++ b/tests/arduino/app_core/test_edge_impulse.py @@ -24,7 +24,7 @@ def mock_infra(monkeypatch: pytest.MonkeyPatch): monkeypatch.setattr("arduino.app_internal.core.resolve_address", lambda h: "127.0.0.1") monkeypatch.setattr("arduino.app_internal.core.parse_docker_compose_variable", lambda s: [(None, None), (None, "1337")]) # identity for get_image_bytes - monkeypatch.setattr("arduino.app_utils.get_image_bytes", lambda b: b) + monkeypatch.setattr("arduino.app_utils.image.get_image_bytes", lambda b: b) @pytest.fixture diff --git a/tests/arduino/app_utils/image/test_adjustments.py b/tests/arduino/app_utils/image/test_adjustments.py new file mode 100644 index 00000000..10b742a5 --- /dev/null +++ b/tests/arduino/app_utils/image/test_adjustments.py @@ -0,0 +1,442 @@ +# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA +# +# SPDX-License-Identifier: MPL-2.0 + +import numpy as np +import pytest +from arduino.app_utils.image.adjustments import letterbox, resize, adjust, split_channels, greyscale + + +# FIXTURES + + +def create_gradient_frame(dtype): + """Helper: Creates a 100x100 3-channel (BGR) frame with gradients.""" + iinfo = np.iinfo(dtype) + max_val = iinfo.max + frame = np.zeros((100, 100, 3), dtype=dtype) + frame[:, :, 0] = np.linspace(0, max_val // 2, 100, dtype=dtype) # Blue + frame[:, :, 1] = np.linspace(0, max_val, 100, dtype=dtype) # Green + frame[:, :, 2] = np.linspace(max_val // 2, max_val, 100, dtype=dtype) # Red + return frame + + +def create_greyscale_frame(dtype): + """Helper: Creates a 100x100 1-channel (greyscale) frame.""" + iinfo = np.iinfo(dtype) + max_val = iinfo.max + frame = np.zeros((100, 100), dtype=dtype) + frame[:, :] = np.linspace(0, max_val, 100, dtype=dtype) + return frame + + +def create_bgra_frame(dtype): + """Helper: Creates a 100x100 4-channel (BGRA) frame.""" + iinfo = np.iinfo(dtype) + max_val = iinfo.max + bgr = create_gradient_frame(dtype) + alpha = np.zeros((100, 100), dtype=dtype) + alpha[:, :] = np.linspace(max_val // 4, max_val, 100, dtype=dtype) + frame = np.stack([bgr[:, :, 0], bgr[:, :, 1], bgr[:, :, 2], alpha], axis=2) + return frame + + +# Fixture for a 100x100 uint8 BGR frame +@pytest.fixture +def frame_bgr_uint8(): + return create_gradient_frame(np.uint8) + + +# Fixture for a 100x100 uint8 BGRA frame +@pytest.fixture +def frame_bgra_uint8(): + return create_bgra_frame(np.uint8) + + +# Fixture for a 100x100 uint8 greyscale frame +@pytest.fixture +def frame_grey_uint8(): + return create_greyscale_frame(np.uint8) + + +# Fixtures for high bit-depth frames +@pytest.fixture +def frame_bgr_uint16(): + return create_gradient_frame(np.uint16) + + +@pytest.fixture +def frame_bgr_uint32(): + return create_gradient_frame(np.uint32) + + +@pytest.fixture +def frame_bgra_uint16(): + return create_bgra_frame(np.uint16) + + 
+@pytest.fixture +def frame_bgra_uint32(): + return create_bgra_frame(np.uint32) + + +# Fixture for a 200x100 (wide) uint8 BGR frame +@pytest.fixture +def frame_bgr_wide(): + frame = np.zeros((100, 200, 3), dtype=np.uint8) + frame[:, :, 2] = 255 # Solid Red + return frame + + +# Fixture for a 100x200 (tall) uint8 BGR frame +@pytest.fixture +def frame_bgr_tall(): + frame = np.zeros((200, 100, 3), dtype=np.uint8) + frame[:, :, 1] = 255 # Solid Green + return frame + + +# A parameterized fixture to test multiple data types +@pytest.fixture(params=[np.uint8, np.uint16, np.uint32]) +def frame_any_dtype(request): + """Provides a gradient frame for uint8, uint16, and uint32.""" + return create_gradient_frame(request.param) + + +# TESTS + + +def test_adjust_dtype_preservation(frame_any_dtype): + """Tests that the dtype of the frame is preserved.""" + dtype = frame_any_dtype.dtype + adjusted = adjust(frame_any_dtype, brightness=0.1) + assert adjusted.dtype == dtype + + +def test_adjust_no_op(frame_bgr_uint8): + """Tests that default parameters do not change the frame.""" + adjusted = adjust(frame_bgr_uint8) + assert np.array_equal(frame_bgr_uint8, adjusted) + + +def test_adjust_brightness(frame_bgr_uint8): + """Tests brightness adjustment.""" + brighter = adjust(frame_bgr_uint8, brightness=0.1) + darker = adjust(frame_bgr_uint8, brightness=-0.1) + assert np.mean(brighter) > np.mean(frame_bgr_uint8) + assert np.mean(darker) < np.mean(frame_bgr_uint8) + + +def test_adjust_contrast(frame_bgr_uint8): + """Tests contrast adjustment.""" + higher_contrast = adjust(frame_bgr_uint8, contrast=1.5) + lower_contrast = adjust(frame_bgr_uint8, contrast=0.5) + assert np.std(higher_contrast) > np.std(frame_bgr_uint8) + assert np.std(lower_contrast) < np.std(frame_bgr_uint8) + + +def test_adjust_gamma(frame_bgr_uint8): + """Tests gamma correction.""" + # Gamma < 1.0 (e.g., 0.5) ==> brightens + brighter = adjust(frame_bgr_uint8, gamma=0.5) + # Gamma > 1.0 (e.g., 2.0) ==> darkens + darker = adjust(frame_bgr_uint8, gamma=2.0) + assert np.mean(brighter) > np.mean(frame_bgr_uint8) + assert np.mean(darker) < np.mean(frame_bgr_uint8) + + +def test_adjust_saturation_to_greyscale(frame_bgr_uint8): + """Tests that saturation=0.0 makes all color channels equal.""" + desaturated = adjust(frame_bgr_uint8, saturation=0.0) + b, g, r = split_channels(desaturated) + assert np.allclose(b, g, atol=1) + assert np.allclose(g, r, atol=1) + + +def test_adjust_greyscale_input(frame_grey_uint8): + """Tests that greyscale frames are handled safely.""" + adjusted = adjust(frame_grey_uint8, saturation=1.5, brightness=0.1) + assert adjusted.ndim == 2 + assert adjusted.dtype == np.uint8 + assert np.mean(adjusted) > np.mean(frame_grey_uint8) + + +def test_adjust_bgra_input(frame_bgra_uint8): + """Tests that BGRA frames are handled safely and alpha is preserved.""" + original_alpha = frame_bgra_uint8[:, :, 3] + + adjusted = adjust(frame_bgra_uint8, saturation=0.0, brightness=0.1) + + assert adjusted.ndim == 3 + assert adjusted.shape[2] == 4 + assert adjusted.dtype == np.uint8 + + b, g, r, a = split_channels(adjusted) + assert np.allclose(b, g, atol=1) # Check desaturation + assert np.allclose(g, r, atol=1) # Check desaturation + assert np.array_equal(original_alpha, a) # Check alpha preservation + + +def test_adjust_gamma_zero_error(frame_bgr_uint8): + """Tests that gamma <= 0 raises a ValueError.""" + with pytest.raises(ValueError, match="Gamma value must be greater than 0."): + adjust(frame_bgr_uint8, gamma=0.0) + + with 
pytest.raises(ValueError, match="Gamma value must be greater than 0."): + adjust(frame_bgr_uint8, gamma=-1.0) + + +def test_adjust_high_bit_depth_bgr(frame_bgr_uint16, frame_bgr_uint32): + """ + Tests that brightness/contrast logic is correct on high bit-depth images. + This validates that the float64 conversion is working. + """ + # Test uint16 + brighter_16 = adjust(frame_bgr_uint16, brightness=0.1) + darker_16 = adjust(frame_bgr_uint16, brightness=-0.1) + assert np.mean(brighter_16) > np.mean(frame_bgr_uint16) + assert np.mean(darker_16) < np.mean(frame_bgr_uint16) + + # Test uint32 + brighter_32 = adjust(frame_bgr_uint32, brightness=0.1) + darker_32 = adjust(frame_bgr_uint32, brightness=-0.1) + assert np.mean(brighter_32) > np.mean(frame_bgr_uint32) + assert np.mean(darker_32) < np.mean(frame_bgr_uint32) + + +def test_adjust_high_bit_depth_bgra(frame_bgra_uint16, frame_bgra_uint32): + """ + Tests that brightness/contrast logic is correct on high bit-depth + BGRA images and that the alpha channel is preserved. + """ + # Test uint16 + original_alpha_16 = frame_bgra_uint16[:, :, 3] + brighter_16 = adjust(frame_bgra_uint16, brightness=0.1) + assert brighter_16.dtype == np.uint16 + assert brighter_16.shape == frame_bgra_uint16.shape + _, _, _, a16 = split_channels(brighter_16) + assert np.array_equal(original_alpha_16, a16) + assert np.mean(brighter_16) > np.mean(frame_bgra_uint16) + + # Test uint32 + original_alpha_32 = frame_bgra_uint32[:, :, 3] + brighter_32 = adjust(frame_bgra_uint32, brightness=0.1) + assert brighter_32.dtype == np.uint32 + assert brighter_32.shape == frame_bgra_uint32.shape + _, _, _, a32 = split_channels(brighter_32) + assert np.array_equal(original_alpha_32, a32) + assert np.mean(brighter_32) > np.mean(frame_bgra_uint32) + + +def test_greyscale(frame_bgr_uint8, frame_bgra_uint8, frame_grey_uint8): + """Tests the standalone greyscale function.""" + # Test on BGR + greyscaled_bgr = greyscale(frame_bgr_uint8) + assert greyscaled_bgr.ndim == 3 + assert greyscaled_bgr.shape[2] == 3 + b, g, r = split_channels(greyscaled_bgr) + assert np.allclose(b, g, atol=1) + assert np.allclose(g, r, atol=1) + + # Test on BGRA + original_alpha = frame_bgra_uint8[:, :, 3] + greyscaled_bgra = greyscale(frame_bgra_uint8) + assert greyscaled_bgra.ndim == 3 + assert greyscaled_bgra.shape[2] == 4 + b, g, r, a = split_channels(greyscaled_bgra) + assert np.allclose(b, g, atol=1) + assert np.allclose(g, r, atol=1) + assert np.array_equal(original_alpha, a) + + # Test on 2D Greyscale (should be no-op) + greyscaled_grey = greyscale(frame_grey_uint8) + assert np.array_equal(frame_grey_uint8, greyscaled_grey) + assert greyscaled_grey.ndim == 2 + + +def test_greyscale_dtype_preservation(frame_any_dtype): + """Tests that the dtype of the frame is preserved.""" + dtype = frame_any_dtype.dtype + greyscaled_frame = greyscale(frame_any_dtype) + assert greyscaled_frame.dtype == dtype + + +def test_greyscale_high_bit_depth(frame_bgr_uint16, frame_bgr_uint32): + """ + Tests that greyscale logic is correct on high bit-depth images.
+ """ + # Test uint16 + greyscaled_16 = greyscale(frame_bgr_uint16) + assert greyscaled_16.dtype == np.uint16 + assert greyscaled_16.shape == frame_bgr_uint16.shape + b16, g16, r16 = split_channels(greyscaled_16) + assert np.allclose(b16, g16, atol=1) + assert np.allclose(g16, r16, atol=1) + assert np.mean(b16) != np.mean(frame_bgr_uint16[:, :, 0]) + + # Test uint32 + greyscaled_32 = greyscale(frame_bgr_uint32) + assert greyscaled_32.dtype == np.uint32 + assert greyscaled_32.shape == frame_bgr_uint32.shape + b32, g32, r32 = split_channels(greyscaled_32) + assert np.allclose(b32, g32, atol=1) + assert np.allclose(g32, r32, atol=1) + assert np.mean(b32) != np.mean(frame_bgr_uint32[:, :, 0]) + + +def test_high_bit_depth_greyscale_bgra_content(frame_bgra_uint16, frame_bgra_uint32): + """ + Tests that greyscale logic is correct on high bit-depth + BGRA images and that the alpha channel is preserved. + """ + # Test uint16 + original_alpha_16 = frame_bgra_uint16[:, :, 3] + greyscaled_16 = greyscale(frame_bgra_uint16) + assert greyscaled_16.dtype == np.uint16 + assert greyscaled_16.shape == frame_bgra_uint16.shape + b16, g16, r16, a16 = split_channels(greyscaled_16) + assert np.allclose(b16, g16, atol=1) + assert np.allclose(g16, r16, atol=1) + assert np.array_equal(original_alpha_16, a16) + + # Test uint32 + original_alpha_32 = frame_bgra_uint32[:, :, 3] + greyscaled_32 = greyscale(frame_bgra_uint32) + assert greyscaled_32.dtype == np.uint32 + assert greyscaled_32.shape == frame_bgra_uint32.shape + b32, g32, r32, a32 = split_channels(greyscaled_32) + assert np.allclose(b32, g32, atol=1) + assert np.allclose(g32, r32, atol=1) + assert np.array_equal(original_alpha_32, a32) + + +def test_resize_shape_and_dtype(frame_bgr_uint8, frame_bgra_uint8, frame_grey_uint8): + """Tests that resize produces the correct shape and preserves dtype.""" + target_w, target_h = 50, 75 + + # Test BGR + resized_bgr = resize(frame_bgr_uint8, (target_w, target_h)) + assert resized_bgr.shape == (target_h, target_w, 3) + assert resized_bgr.dtype == frame_bgr_uint8.dtype + + # Test BGRA + resized_bgra = resize(frame_bgra_uint8, (target_w, target_h)) + assert resized_bgra.shape == (target_h, target_w, 4) + assert resized_bgra.dtype == frame_bgra_uint8.dtype + + # Test Greyscale + resized_grey = resize(frame_grey_uint8, (target_w, target_h)) + assert resized_grey.shape == (target_h, target_w) + assert resized_grey.dtype == frame_grey_uint8.dtype + + +def test_letterbox_wide_image(frame_bgr_wide): + """Tests letterboxing a wide image (200x100) into a square (200x200).""" + target_w, target_h = 200, 200 + # Frame is 200x100, solid red (255) + # Scale = min(200/200, 200/100) = min(1, 2) = 1 + # new_w = 200 * 1 = 200 + # new_h = 100 * 1 = 100 + # y_offset = (200 - 100) // 2 = 50 + # x_offset = (200 - 200) // 2 = 0 + + letterboxed = letterbox(frame_bgr_wide, (target_w, target_h), color=0) + + assert letterboxed.shape == (target_h, target_w, 3) + assert letterboxed.dtype == frame_bgr_wide.dtype + + # Check padding (top row, black) + assert np.all(letterboxed[0, 0] == [0, 0, 0]) + # Check padding (bottom row, black) + assert np.all(letterboxed[199, 199] == [0, 0, 0]) + # Check image data (center row, red) + assert np.all(letterboxed[100, 100] == [0, 0, 255]) + # Check image edge (no left/right padding) + assert np.all(letterboxed[100, 0] == [0, 0, 255]) + + +def test_letterbox_tall_image(frame_bgr_tall): + """Tests letterboxing a tall image (100x200) into a square (200x200).""" + target_w, target_h = 200, 200 + # Frame is 100x200, 
+def test_letterbox_tall_image(frame_bgr_tall):
+    """Tests letterboxing a tall image (100x200) into a square (200x200)."""
+    target_w, target_h = 200, 200
+    # Frame is 100x200, solid green (255)
+    # Scale = min(200/100, 200/200) = min(2, 1) = 1
+    # new_w = 100 * 1 = 100
+    # new_h = 200 * 1 = 200
+    # y_offset = (200 - 200) // 2 = 0
+    # x_offset = (200 - 100) // 2 = 50
+
+    letterboxed = letterbox(frame_bgr_tall, (target_w, target_h), color=0)
+
+    assert letterboxed.shape == (target_h, target_w, 3)
+    assert letterboxed.dtype == frame_bgr_tall.dtype
+
+    # Check padding (left column, black)
+    assert np.all(letterboxed[0, 0] == [0, 0, 0])
+    # Check padding (right column, black)
+    assert np.all(letterboxed[199, 199] == [0, 0, 0])
+    # Check image data (center column, green)
+    assert np.all(letterboxed[100, 100] == [0, 255, 0])
+    # Check image edge (no top/bottom padding)
+    assert np.all(letterboxed[0, 100] == [0, 255, 0])
+
+
+def test_letterbox_color(frame_bgr_tall):
+    """Tests letterboxing with a non-default color."""
+    white = (255, 255, 255)
+    letterboxed = letterbox(frame_bgr_tall, (200, 200), color=white)
+
+    # Check padding (left column, white)
+    assert np.all(letterboxed[0, 0] == white)
+    # Check image data (center column, green)
+    assert np.all(letterboxed[100, 100] == [0, 255, 0])
+
+
+def test_letterbox_bgra(frame_bgra_uint8):
+    """Tests letterboxing on a 4-channel BGRA image."""
+    target_w, target_h = 200, 200
+    # Opaque black padding
+    padding = (0, 0, 0, 255)
+
+    letterboxed = letterbox(frame_bgra_uint8, (target_w, target_h), color=padding)
+
+    assert letterboxed.shape == (target_h, target_w, 4)
+    # Check no padding (corner, original BGRA point)
+    assert np.array_equal(letterboxed[0, 0], frame_bgra_uint8[0, 0])
+    # Check image data (center, from fixture) - allow small tolerance for numerical precision differences
+    assert np.allclose(letterboxed[100, 100], frame_bgra_uint8[50, 50], atol=1)
+
+
+def test_letterbox_greyscale(frame_grey_uint8):
+    """Tests letterboxing on a 2D greyscale image."""
+    target_w, target_h = 200, 200
+    letterboxed = letterbox(frame_grey_uint8, (target_w, target_h), color=0)
+
+    assert letterboxed.shape == (target_h, target_w)
+    assert letterboxed.ndim == 2
+    # Check padding (corner, black)
+    assert letterboxed[0, 0] == 0
+    # Check image data (center) - allow small tolerance for numerical precision differences
+    assert np.allclose(letterboxed[100, 100], frame_grey_uint8[50, 50], atol=1)
+
+
+def test_letterbox_none_target_size(frame_bgr_wide, frame_bgr_tall):
+    """Tests that target_size=None creates a square based on the longest side."""
+    # frame_bgr_wide is 200x100, longest side is 200
+    letterboxed_wide = letterbox(frame_bgr_wide, target_size=None)
+    assert letterboxed_wide.shape == (200, 200, 3)
+
+    # frame_bgr_tall is 100x200, longest side is 200
+    letterboxed_tall = letterbox(frame_bgr_tall, target_size=None)
+    assert letterboxed_tall.shape == (200, 200, 3)
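+
+# Reviewer note: the ValueError expected by the test below assumes letterbox
+# validates that a tuple color matches the frame's channel count. A minimal
+# sketch of that check (illustrative, not the library code; the exact message
+# is assumed beyond the "color length" fragment the test matches):
+#
+#     channels = frame.shape[2] if frame.ndim == 3 else 1
+#     if isinstance(color, tuple) and len(color) != channels:
+#         raise ValueError(f"color length {len(color)} does not match frame channels {channels}")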
+def test_letterbox_color_tuple_error(frame_bgr_uint8):
+    """Tests that a mismatched padding tuple raises a ValueError."""
+    with pytest.raises(ValueError, match="color length"):
+        # BGR (3-ch) frame with 4-ch padding
+        letterbox(frame_bgr_uint8, (200, 200), color=(0, 0, 0, 0))
+
+    with pytest.raises(ValueError, match="color length"):
+        # BGRA (4-ch) frame with 3-ch padding
+        frame_bgra = create_bgra_frame(np.uint8)
+        letterbox(frame_bgra, (200, 200), color=(0, 0, 0))
diff --git a/tests/arduino/app_utils/image/test_pipeable.py b/tests/arduino/app_utils/image/test_pipeable.py
new file mode 100644
index 00000000..29cf5a97
--- /dev/null
+++ b/tests/arduino/app_utils/image/test_pipeable.py
@@ -0,0 +1,191 @@
+# SPDX-FileCopyrightText: Copyright (C) 2025 ARDUINO SA
+#
+# SPDX-License-Identifier: MPL-2.0
+
+import pytest
+from unittest.mock import MagicMock
+from arduino.app_utils.image.pipeable import PipeableFunction
+
+
+class TestPipeableFunction:
+    """Test cases for the PipeableFunction class."""
+
+    def test_init(self):
+        """Test PipeableFunction initialization."""
+        mock_func = MagicMock()
+        pf = PipeableFunction(mock_func, 1, 2, kwarg1="value1")
+
+        assert pf.func == mock_func
+        assert pf.args == (1, 2)
+        assert pf.kwargs == {"kwarg1": "value1"}
+
+    def test_call_no_existing_args(self):
+        """Test calling PipeableFunction with no existing args."""
+        mock_func = MagicMock(return_value="result")
+        pf = PipeableFunction(mock_func)
+
+        result = pf(1, 2, kwarg1="value1")
+
+        mock_func.assert_called_once_with(1, 2, kwarg1="value1")
+        assert result == "result"
+
+    def test_call_with_existing_args(self):
+        """Test calling PipeableFunction with existing args."""
+        mock_func = MagicMock(return_value="result")
+        pf = PipeableFunction(mock_func, 1, kwarg1="value1")
+
+        result = pf(2, 3, kwarg2="value2")
+
+        mock_func.assert_called_once_with(1, 2, 3, kwarg1="value1", kwarg2="value2")
+        assert result == "result"
+
+    def test_call_kwargs_override(self):
+        """Test that new kwargs override existing ones."""
+        mock_func = MagicMock(return_value="result")
+        pf = PipeableFunction(mock_func, kwarg1="old_value")
+
+        result = pf(kwarg1="new_value", kwarg2="value2")
+
+        mock_func.assert_called_once_with(kwarg1="new_value", kwarg2="value2")
+        assert result == "result"
+
+    def test_ror_pipe_operator(self):
+        """Test right-hand side pipe operator (value | function)."""
+
+        def add_one(x):
+            return x + 1
+
+        pf = PipeableFunction(add_one)
+        result = 5 | pf
+
+        assert result == 6
+
+    def test_or_pipe_operator(self):
+        """Test left-hand side pipe operator (function | function)."""
+
+        def add_one(x):
+            return x + 1
+
+        def multiply_two(x):
+            return x * 2
+
+        pf1 = PipeableFunction(add_one)
+        pf2 = PipeableFunction(multiply_two)
+
+        # Chain: add_one | multiply_two
+        composed = pf1 | pf2
+
+        assert isinstance(composed, PipeableFunction)
+        result = composed(5)  # (5 + 1) * 2 = 12
+        assert result == 12
+
+    def test_or_pipe_operator_with_non_callable(self):
+        """Test that piping into a non-callable raises TypeError (__or__ returns NotImplemented)."""
+        pf = PipeableFunction(lambda x: x)
+        with pytest.raises(TypeError, match="unsupported operand type"):
+            pf | "not_callable"
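+
+    # Reviewer note: the call/pipe tests above assume semantics along these
+    # lines (illustrative sketch, not the library implementation):
+    #
+    #     class PipeableFunction:
+    #         def __init__(self, func, *args, **kwargs):
+    #             self.func, self.args, self.kwargs = func, args, kwargs
+    #
+    #         def __call__(self, *args, **kwargs):
+    #             # Stored args come first; new kwargs win on conflict.
+    #             return self.func(*self.args, *args, **{**self.kwargs, **kwargs})
+    #
+    #         def __ror__(self, value):
+    #             return self(value)  # value | pf
+    #
+    #         def __or__(self, other):
+    #             if not callable(other):
+    #                 return NotImplemented  # Python then raises TypeError
+    #             return PipeableFunction(lambda *a, **k: other(self(*a, **k)))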
+    def test_repr_with_function_name(self):
+        """Test string representation with a function having __name__."""
+
+        def test_func():
+            pass
+
+        pf = PipeableFunction(test_func)
+        assert repr(pf) == "test_func()"
+
+    def test_repr_with_args_and_kwargs(self):
+        """Test string representation with args and kwargs."""
+
+        def test_func():
+            pass
+
+        pf = PipeableFunction(test_func, 1, 2, kwarg1="value1", kwarg2=42)
+        repr_str = repr(pf)
+
+        assert "test_func(" in repr_str
+        assert "1" in repr_str
+        assert "2" in repr_str
+        assert "kwarg1=value1" in repr_str
+        assert "kwarg2=42" in repr_str
+
+    def test_repr_with_partial_object(self):
+        """Test string representation with a functools.partial object."""
+        from functools import partial
+
+        def test_func(a, b):
+            return a + b
+
+        partial_func = partial(test_func, b=10)
+        pf = PipeableFunction(partial_func)
+
+        repr_str = repr(pf)
+        assert "test_func" in repr_str or "partial" in repr_str
+
+    def test_repr_with_callable_without_name(self):
+        """Test string representation with a callable lacking __name__."""
+
+        class CallableClass:
+            def __call__(self):
+                pass
+
+        callable_obj = CallableClass()
+        pf = PipeableFunction(callable_obj)
+
+        repr_str = repr(pf)
+        assert "CallableClass" in repr_str
+
+
+class TestPipeableFunctionIntegration:
+    """Integration tests for the PipeableFunction class."""
+
+    def test_real_world_data_processing(self):
+        """Test pipeable with a real-world data processing scenario."""
+
+        def filter_positive(numbers):
+            return [n for n in numbers if n > 0]
+
+        def filtered_positive():
+            return PipeableFunction(filter_positive)
+
+        def square_all(numbers):
+            return [n * n for n in numbers]
+
+        def squared():
+            return PipeableFunction(square_all)
+
+        def sum_all(numbers):
+            return sum(numbers)
+
+        def summed():
+            return PipeableFunction(sum_all)
+
+        data = [-2, -1, 0, 1, 2, 3]
+
+        # Pipeline: filter positive -> square -> sum
+        # [1, 2, 3] -> [1, 4, 9] -> 14
+        result = data | filtered_positive() | squared() | summed()
+        assert result == 14
+
+    def test_error_handling_in_pipeline(self):
+        """Test error handling within pipelines."""
+
+        def divide_by(x, divisor):
+            return x / divisor  # May raise ZeroDivisionError
+
+        def divided_by(divisor):
+            return PipeableFunction(divide_by, divisor=divisor)
+
+        def round_number(x, decimals=2):
+            return round(x, decimals)
+
+        def rounded(decimals=2):
+            return PipeableFunction(round_number, decimals=decimals)
+
+        # Test successful pipeline
+        result = 10 | divided_by(3) | rounded(decimals=2)
+        assert result == 3.33
+
+        # Test error propagation
+        with pytest.raises(ZeroDivisionError):
+            10 | divided_by(0) | rounded()
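+
+# Reviewer note: a hypothetical end-to-end use of the pattern tested above,
+# assuming the image helpers expose pipeable-returning factories (the factory
+# names here are illustrative only, not a documented API):
+#
+#     processed = frame | greyscaled() | resized((640, 360))
+#
+# Each stage receives the previous stage's output as its first positional
+# argument, matching the __ror__/__call__ behaviour exercised in these tests.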