diff --git a/snapcast/control/client.py b/snapcast/control/client.py index af14496..f49ead9 100644 --- a/snapcast/control/client.py +++ b/snapcast/control/client.py @@ -1,84 +1,84 @@ """Snapcast client.""" import logging - +from typing import Any, Callable, Dict, List, Optional, Union _LOGGER = logging.getLogger(__name__) # pylint: disable=too-many-public-methods -class Snapclient(): - """Represents a snapclient.""" +class Snapclient: + """Initialize the Client object.""" - def __init__(self, server, data): + def __init__(self, server: Any, data: Dict[str, Any]) -> None: """Initialize.""" self._server = server - self._snapshot = None - self._last_seen = None - self._callback_func = None + self._snapshot: Optional[Dict[str, Union[str, int, bool]]] = None + self._last_seen: Optional[str] = None + self._callback_func: Optional[Callable[[Any], None]] = None self.update(data) - def update(self, data): + def update(self, data: Dict[str, Any]) -> None: """Update client.""" self._client = data @property - def identifier(self): - """Get identifier.""" + def identifier(self) -> Optional[str]: + """Get client identifier.""" return self._client.get('id') @property - def group(self): - """Get group.""" + def group(self) -> Optional[Any]: + """Get group that the client is part of""" for group in self._server.groups: if self.identifier in group.clients: return group return None @property - def friendly_name(self): - """Get friendly name.""" - if len(self._client.get('config').get('name')): - return self._client.get('config').get('name') - return self._client.get('host').get('name') + def friendly_name(self) -> str: + """Get client friendly name.""" + if len(self._client.get('config', {}).get('name', '')): + return self._client.get('config').get('name', '') + return self._client.get('host', {}).get('name', '') @property - def version(self): - """Version.""" - return self._client.get('snapclient').get('version') + def version(self) -> Optional[str]: + """Get client snapclient version.""" + return self._client.get('snapclient', {}).get('version') @property - def connected(self): - """Connected or not.""" - return self._client.get('connected') + def connected(self) -> bool: + """Get the current connection status of the client.""" + return self._client.get('connected', False) @property - def name(self): - """Name.""" - return self._client.get('config').get('name') + def name(self) -> str: + """Get name of the client.""" + return self._client.get('config', {}).get('name', '') - async def set_name(self, name): - """Set a client name.""" + async def set_name(self, name: str) -> None: + """Set a new name for the client.""" if not name: name = '' self._client['config']['name'] = name await self._server.client_name(self.identifier, name) @property - def latency(self): - """Latency.""" - return self._client.get('config').get('latency') + def latency(self) -> Optional[int]: + """Get client latency.""" + return self._client.get('config', {}).get('latency') - async def set_latency(self, latency): + async def set_latency(self, latency: int) -> None: """Set client latency.""" self._client['config']['latency'] = latency await self._server.client_latency(self.identifier, latency) @property - def muted(self): + def muted(self) -> bool: """Muted or not.""" - return self._client.get('config').get('volume').get('muted') + return self._client.get('config', {}).get('volume', {}).get('muted', False) - async def set_muted(self, status): + async def set_muted(self, status: bool) -> None: """Set client mute status.""" new_volume = self._client['config']['volume'] new_volume['muted'] = status @@ -87,11 +87,11 @@ async def set_muted(self, status): _LOGGER.debug('set muted to %s on %s', status, self.friendly_name) @property - def volume(self): - """Volume percent.""" - return self._client.get('config').get('volume').get('percent') + def volume(self) -> int: + """Get client volume percent.""" + return self._client.get('config', {}).get('volume', {}).get('percent', 0) - async def set_volume(self, percent, update_group=True): + async def set_volume(self, percent: int, update_group: bool = True) -> None: """Set client volume percent.""" if percent not in range(0, 101): raise ValueError('Volume percent out of range') @@ -103,37 +103,44 @@ async def set_volume(self, percent, update_group=True): self._server.group(self.group.identifier).callback() _LOGGER.debug('set volume to %s on %s', percent, self.friendly_name) - def groups_available(self): + def groups_available(self) -> List[Any]: """Get available group objects.""" return list(self._server.groups) - def update_volume(self, data): + def update_volume(self, data: Dict[str, Any]) -> None: """Update volume.""" self._client['config']['volume'] = data['volume'] _LOGGER.debug('updated volume on %s', self.friendly_name) self._server.group(self.group.identifier).callback() self.callback() - def update_name(self, data): + def update_name(self, data: Dict[str, Any]) -> None: """Update name.""" self._client['config']['name'] = data['name'] _LOGGER.debug('updated name on %s', self.friendly_name) self.callback() - def update_latency(self, data): + def update_latency(self, data: Dict[str, Any]) -> None: """Update latency.""" self._client['config']['latency'] = data['latency'] _LOGGER.debug('updated latency on %s', self.friendly_name) self.callback() - def update_connected(self, status): + def update_connected(self, status: bool) -> None: """Update connected.""" self._client['connected'] = status _LOGGER.debug('updated connected status to %s on %s', status, self.friendly_name) self.callback() - def snapshot(self): - """Snapshot current state.""" + def snapshot(self) -> None: + """Snapshot current state of the client. + + Snapshot: + - Client name + - Client volume + - Client muting status + - Client latency + """ self._snapshot = { 'name': self.name, 'volume': self.volume, @@ -142,8 +149,14 @@ def snapshot(self): } _LOGGER.debug('took snapshot of current state of %s', self.friendly_name) - async def restore(self): - """Restore snapshotted state.""" + async def restore(self) -> None: + """Restore snapshotted state. + Snapshot: + - Client name + - Client volume + - Client muting status + - Client latency + """ if not self._snapshot: return await self.set_name(self._snapshot['name']) @@ -153,15 +166,15 @@ async def restore(self): self.callback() _LOGGER.debug('restored snapshot of state of %s', self.friendly_name) - def callback(self): - """Run callback.""" + def callback(self) -> None: + """Run callback function if set.""" if self._callback_func and callable(self._callback_func): self._callback_func(self) - def set_callback(self, func): + def set_callback(self, func: Callable[[Any], None]) -> None: """Set callback function.""" self._callback_func = func - def __repr__(self): - """Return string representation.""" + def __repr__(self) -> str: + """Return string representation of the client.""" return f'Snapclient {self.version} ({self.friendly_name}, {self.identifier})' diff --git a/snapcast/control/group.py b/snapcast/control/group.py index 98bbad1..a95b981 100644 --- a/snapcast/control/group.py +++ b/snapcast/control/group.py @@ -1,36 +1,36 @@ """Snapcast group.""" import logging - +from typing import Any, Callable, Dict, List, Optional, Union _LOGGER = logging.getLogger(__name__) - # pylint: disable=too-many-public-methods -class Snapgroup(): + +class Snapgroup: """Represents a snapcast group.""" - def __init__(self, server, data): - """Initialize.""" - self._server = server - self._snapshot = None - self._callback_func = None + def __init__(self, server: Any, data: Dict[str, Any]) -> None: + """Initialize the group object.""" + self._server: Any = server + self._snapshot: Optional[Dict[str, Union[int, str, bool]]] = None + self._callback_func: Optional[Callable[[Any], None]] = None self.update(data) - def update(self, data): - """Update group.""" - self._group = data + def update(self, data: Dict[str, Any]) -> None: + """Update group data.""" + self._group: Dict[str, Any] = data @property - def identifier(self): + def identifier(self) -> str: """Get group identifier.""" - return self._group.get('id') + return self._group.get('id', '') @property - def name(self): + def name(self) -> str: """Get group name.""" - return self._group.get('name') + return self._group.get('name', '') - async def set_name(self, name): + async def set_name(self, name: str) -> None: """Set a group name.""" if not name: name = '' @@ -38,41 +38,41 @@ async def set_name(self, name): await self._server.group_name(self.identifier, name) @property - def stream(self): + def stream(self) -> str: """Get stream identifier.""" - return self._group.get('stream_id') + return self._group.get('stream_id', '') - async def set_stream(self, stream_id): + async def set_stream(self, stream_id: str) -> None: """Set group stream.""" self._group['stream_id'] = stream_id await self._server.group_stream(self.identifier, stream_id) _LOGGER.debug('set stream to %s on %s', stream_id, self.friendly_name) @property - def stream_status(self): + def stream_status(self) -> Any: """Get stream status.""" return self._server.stream(self.stream).status @property - def muted(self): + def muted(self) -> bool: """Get mute status.""" - return self._group.get('muted') + return self._group.get('muted', False) - async def set_muted(self, status): + async def set_muted(self, status: bool) -> None: """Set group mute status.""" self._group['muted'] = status await self._server.group_mute(self.identifier, status) _LOGGER.debug('set muted to %s on %s', status, self.friendly_name) @property - def volume(self): + def volume(self) -> int: """Get volume.""" volume_sum = 0 - for client in self._group.get('clients'): + for client in self._group.get('clients', []): volume_sum += self._server.client(client.get('id')).volume - return int(volume_sum / len(self._group.get('clients'))) + return int(volume_sum / len(self._group.get('clients', []))) - async def set_volume(self, volume): + async def set_volume(self, volume: int) -> None: """Set volume.""" if volume not in range(0, 101): raise ValueError('Volume out of range') @@ -85,7 +85,7 @@ async def set_volume(self, volume): ratio = (current_volume - volume) / current_volume else: ratio = (volume - current_volume) / (100 - current_volume) - for data in self._group.get('clients'): + for data in self._group.get('clients', []): client = self._server.client(data.get('id')) client_volume = client.volume if delta < 0: @@ -103,20 +103,20 @@ async def set_volume(self, volume): _LOGGER.debug('set volume to %s on group %s', volume, self.friendly_name) @property - def friendly_name(self): - """Get friendly name.""" + def friendly_name(self) -> str: + """Get group friendly name.""" fname = self.name if self.name != '' else "+".join( sorted([self._server.client(c).friendly_name for c in self.clients if c in [client.identifier for client in self._server.clients]])) return fname if fname != '' else self.identifier @property - def clients(self): - """Get client identifiers.""" - return [client.get('id') for client in self._group.get('clients')] + def clients(self) -> List[str]: + """Get all the client identifiers for the group.""" + return [client.get('id', '') for client in self._group.get('clients', [])] - async def add_client(self, client_identifier): - """Add a client.""" + async def add_client(self, client_identifier: str) -> None: + """Add a client to the group.""" if client_identifier in self.clients: _LOGGER.error('%s already in group %s', client_identifier, self.identifier) return @@ -129,8 +129,8 @@ async def add_client(self, client_identifier): self._server.client(client_identifier).callback() self.callback() - async def remove_client(self, client_identifier): - """Remove a client.""" + async def remove_client(self, client_identifier: str) -> None: + """Remove a client from the group.""" new_clients = self.clients new_clients.remove(client_identifier) await self._server.group_clients(self.identifier, new_clients) @@ -140,30 +140,37 @@ async def remove_client(self, client_identifier): self._server.client(client_identifier).callback() self.callback() - def streams_by_name(self): + def streams_by_name(self) -> Dict[str, Any]: """Get available stream objects by name.""" return {stream.friendly_name: stream for stream in self._server.streams} - def update_mute(self, data): + def update_mute(self, data: Dict[str, Any]) -> None: """Update mute.""" self._group['muted'] = data['mute'] self.callback() _LOGGER.debug('updated mute on %s', self.friendly_name) - def update_name(self, data): + def update_name(self, data: Dict[str, Any]) -> None: """Update name.""" self._group['name'] = data['name'] _LOGGER.debug('updated name on %s', self.name) self.callback() - def update_stream(self, data): + def update_stream(self, data: Dict[str, Any]) -> None: """Update stream.""" self._group['stream_id'] = data['stream_id'] self.callback() _LOGGER.debug('updated stream to %s on %s', self.stream, self.friendly_name) - def snapshot(self): - """Snapshot current state.""" + def snapshot(self) -> None: + """Snapshot current state. + + Snapshot: + - Group muting status + - Group volume + - Group stream identifier + + """ self._snapshot = { 'muted': self.muted, 'volume': self.volume, @@ -171,8 +178,13 @@ def snapshot(self): } _LOGGER.debug('took snapshot of current state of %s', self.friendly_name) - async def restore(self): - """Restore snapshotted state.""" + async def restore(self) -> None: + """Restore snapshotted state. + Snapshot: + - Group muting status + - Group volume + - Group stream identifier + """ if not self._snapshot: return await self.set_muted(self._snapshot['muted']) @@ -181,15 +193,15 @@ async def restore(self): self.callback() _LOGGER.debug('restored snapshot of state of %s', self.friendly_name) - def callback(self): - """Run callback.""" + def callback(self) -> None: + """Run callback function if set.""" if self._callback_func and callable(self._callback_func): self._callback_func(self) - def set_callback(self, func): - """Set callback.""" + def set_callback(self, func: Callable[[Any], None]) -> None: + """Set callback function.""" self._callback_func = func - def __repr__(self): - """Return string representation.""" + def __repr__(self) -> str: + """Return string representation of the group.""" return f'Snapgroup ({self.friendly_name}, {self.identifier})' diff --git a/snapcast/control/protocol.py b/snapcast/control/protocol.py index 9a21050..dc018f0 100644 --- a/snapcast/control/protocol.py +++ b/snapcast/control/protocol.py @@ -3,12 +3,13 @@ import asyncio import json import random +from typing import Any, Callable, Dict, Optional, Tuple SERVER_ONDISCONNECT = 'Server.OnDisconnect' - # pylint: disable=consider-using-f-string -def jsonrpc_request(method, identifier, params=None): + +def jsonrpc_request(method: str, identifier: int, params: Optional[Dict[str, Any]] = None) -> bytes: """Produce a JSONRPC request.""" return '{}\r\n'.format(json.dumps({ 'id': identifier, @@ -17,29 +18,29 @@ def jsonrpc_request(method, identifier, params=None): 'jsonrpc': '2.0' })).encode() - class SnapcastProtocol(asyncio.Protocol): """Async Snapcast protocol.""" - def __init__(self, callbacks): - """Initialize.""" - self._transport = None - self._buffer = {} - self._callbacks = callbacks - self._data_buffer = '' + def __init__(self, callbacks: Dict[str, Callable[[Any], None]]) -> None: + """Initialize the SnapcastProtocol.""" + self._transport: Optional[asyncio.Transport] = None + self._buffer: Dict[int, Dict[str, Any]] = {} + self._callbacks: Dict[str, Callable[[Any], None]] = callbacks + self._data_buffer: str = '' - def connection_made(self, transport): - """When a connection is made.""" + def connection_made(self, transport: asyncio.Transport) -> None: + """Handle a new connection.""" self._transport = transport - def connection_lost(self, exc): - """When a connection is lost.""" + def connection_lost(self, exc: Optional[Exception]) -> None: + """Handle a lost connection.""" for b in self._buffer.values(): b['error'] = {"code": -1, "message": "connection lost"} b['flag'].set() - self._callbacks.get(SERVER_ONDISCONNECT)(exc) + if SERVER_ONDISCONNECT in self._callbacks: + self._callbacks[SERVER_ONDISCONNECT](exc) - def data_received(self, data): + def data_received(self, data: bytes) -> None: """Handle received data.""" self._data_buffer += data.decode() if not self._data_buffer.endswith('\r\n'): @@ -53,26 +54,27 @@ def data_received(self, data): for item in data: self.handle_data(item) - def handle_data(self, data): + def handle_data(self, data: Dict[str, Any]) -> None: """Handle JSONRPC data.""" if 'id' in data: self.handle_response(data) else: self.handle_notification(data) - def handle_response(self, data): + def handle_response(self, data: Dict[str, Any]) -> None: """Handle JSONRPC response.""" identifier = data.get('id') - self._buffer[identifier]['data'] = data.get('result') - self._buffer[identifier]['error'] = data.get('error') - self._buffer[identifier]['flag'].set() + if identifier in self._buffer: + self._buffer[identifier]['data'] = data.get('result') + self._buffer[identifier]['error'] = data.get('error') + self._buffer[identifier]['flag'].set() - def handle_notification(self, data): + def handle_notification(self, data: Dict[str, Any]) -> None: """Handle JSONRPC notification.""" if data.get('method') in self._callbacks: self._callbacks.get(data.get('method'))(data.get('params')) - async def request(self, method, params): + async def request(self, method: str, params: Optional[Dict[str, Any]] = None) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Send a JSONRPC request.""" identifier = random.randint(1, 1000) self._transport.write(jsonrpc_request(method, identifier, params)) diff --git a/snapcast/control/server.py b/snapcast/control/server.py index 7bd6373..c8518d7 100644 --- a/snapcast/control/server.py +++ b/snapcast/control/server.py @@ -8,6 +8,7 @@ from snapcast.control.group import Snapgroup from snapcast.control.protocol import SERVER_ONDISCONNECT, SnapcastProtocol from snapcast.control.stream import Snapstream +from typing import Any, Callable, Dict, List, Optional, Tuple, Union _LOGGER = logging.getLogger(__name__) @@ -74,24 +75,25 @@ class ServerVersionError(NotImplementedError): # pylint: disable=too-many-public-methods -class Snapserver(): + +class Snapserver: """Represents a snapserver.""" # pylint: disable=too-many-instance-attributes - def __init__(self, loop, host, port=CONTROL_PORT, reconnect=False): + def __init__(self, loop: asyncio.AbstractEventLoop, host: str, port: int = CONTROL_PORT, reconnect: bool = False) -> None: """Initialize.""" - self._loop = loop - self._port = port - self._reconnect = reconnect - self._is_stopped = True - self._clients = {} - self._streams = {} - self._groups = {} - self._host = host - self._version = None - self._protocol = None - self._transport = None - self._callbacks = { + self._loop: asyncio.AbstractEventLoop = loop + self._port: int = port + self._reconnect: bool = reconnect + self._is_stopped: bool = True + self._clients: Dict[str, Any] = {} + self._streams: Dict[str, Any] = {} + self._groups: Dict[str, Any] = {} + self._host: str = host + self._version: Optional[str] = None + self._protocol: Optional[Any] = None + self._transport: Optional[asyncio.Transport] = None + self._callbacks: Dict[str, Callable[[Any], None]] = { CLIENT_ONCONNECT: self._on_client_connect, CLIENT_ONDISCONNECT: self._on_client_disconnect, CLIENT_ONVOLUMECHANGED: self._on_client_volume_changed, @@ -106,12 +108,12 @@ def __init__(self, loop, host, port=CONTROL_PORT, reconnect=False): SERVER_ONDISCONNECT: self._on_server_disconnect, SERVER_ONUPDATE: self._on_server_update } - self._on_update_callback_func = None - self._on_connect_callback_func = None - self._on_disconnect_callback_func = None - self._new_client_callback_func = None + self._on_update_callback_func: Optional[Callable[[], None]] = None + self._on_connect_callback_func: Optional[Callable[[], None]] = None + self._on_disconnect_callback_func: Optional[Callable[[Optional[Exception]], None]] = None + self._new_client_callback_func: Optional[Callable[[Any], None]] = None - async def start(self): + async def start(self) -> None: """Initiate server connection.""" self._is_stopped = False await self._do_connect() @@ -124,8 +126,8 @@ async def start(self): self.synchronize(status) self._on_server_connect() - def stop(self): - """Stop server.""" + def stop(self) -> None: + """Stop server connection.""" self._is_stopped = True self._do_disconnect() _LOGGER.debug('Stopping') @@ -134,22 +136,26 @@ def stop(self): self._groups = {} self._version = None - def _do_disconnect(self): + def _do_disconnect(self) -> None: """Perform the connection to the server.""" if self._transport: self._transport.close() - async def _do_connect(self): + async def _do_connect(self) -> None: """Perform the connection to the server.""" self._transport, self._protocol = await self._loop.create_connection( lambda: SnapcastProtocol(self._callbacks), self._host, self._port) - def _reconnect_cb(self): + def _reconnect_cb(self) -> None: """Try to reconnect to the server.""" _LOGGER.debug('try reconnect') - async def try_reconnect(): - """Actual coroutine ro try to reconnect or reschedule.""" + async def try_reconnect() -> None: + """Actual coroutine to try to reconnect or reschedule. + + Raises: + OSError: If there isn't a valid response from the server. + """ try: await self._do_connect() status, error = await self.status() @@ -158,89 +164,87 @@ async def try_reconnect(): self.stop() raise OSError except OSError: - self._loop.call_later(SERVER_RECONNECT_DELAY, - self._reconnect_cb) + self._loop.call_later(SERVER_RECONNECT_DELAY, self._reconnect_cb) else: self.synchronize(status) self._on_server_connect() asyncio.ensure_future(try_reconnect()) - async def _transact(self, method, params=None): + async def _transact(self, method: str, params: Optional[Dict[str, Any]] = None) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Wrap requests.""" result = error = None if self._protocol is None or self._transport is None or self._transport.is_closing(): error = {"code": None, "message": "Server not connected"} else: result, error = await self._protocol.request(method, params) - return (result, error) + return result, error @property - def version(self): - """Version.""" + def version(self) -> Optional[str]: + """Get server version.""" return self._version - async def status(self): - """System status.""" + async def status(self) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: + """Get system status.""" return await self._transact(SERVER_GETSTATUS) - async def rpc_version(self): - """RPC version.""" + async def rpc_version(self) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: + """Get RPC version.""" return await self._transact(SERVER_GETRPCVERSION) - async def delete_client(self, identifier): - """Delete client.""" + async def delete_client(self, identifier: str) -> None: + """Delete client from the server.""" params = {'id': identifier} response, _ = await self._transact(SERVER_DELETECLIENT, params) self.synchronize(response) - async def client_name(self, identifier, name): + async def client_name(self, identifier: str, name: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Set client name.""" return await self._request(CLIENT_SETNAME, identifier, 'name', name) - async def client_latency(self, identifier, latency): + async def client_latency(self, identifier: str, latency: int) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Set client latency.""" return await self._request(CLIENT_SETLATENCY, identifier, 'latency', latency) - async def client_volume(self, identifier, volume): + async def client_volume(self, identifier: str, volume: int) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Set client volume.""" return await self._request(CLIENT_SETVOLUME, identifier, 'volume', volume) - async def client_status(self, identifier): + async def client_status(self, identifier: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Get client status.""" return await self._request(CLIENT_GETSTATUS, identifier, 'client') - async def group_status(self, identifier): + async def group_status(self, identifier: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Get group status.""" return await self._request(GROUP_GETSTATUS, identifier, 'group') - async def group_mute(self, identifier, status): + async def group_mute(self, identifier: str, status: bool) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Set group mute.""" return await self._request(GROUP_SETMUTE, identifier, 'mute', status) - async def group_stream(self, identifier, stream_id): + async def group_stream(self, identifier: str, stream_id: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Set group stream.""" return await self._request(GROUP_SETSTREAM, identifier, 'stream_id', stream_id) - async def group_clients(self, identifier, clients): + async def group_clients(self, identifier: str, clients: List[str]) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Set group clients.""" return await self._request(GROUP_SETCLIENTS, identifier, 'clients', clients) - async def group_name(self, identifier, name): + async def group_name(self, identifier: str, name: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Set group name.""" self._version_check(GROUP_SETNAME) return await self._request(GROUP_SETNAME, identifier, 'name', name) - async def stream_control(self, identifier, control_command, control_params): + async def stream_control(self, identifier: str, control_command: str, control_params: Dict[str, Any]) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Set stream control.""" self._version_check(STREAM_SETPROPERTY) - return await self._request( - STREAM_CONTROL, identifier, 'command', control_command, control_params) + return await self._request(STREAM_CONTROL, identifier, 'command', control_command, control_params) - async def stream_setmeta(self, identifier, meta): # deprecated + async def stream_setmeta(self, identifier: str, meta: Dict[str, Any]) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: # deprecated """Set stream metadata.""" return await self._request(STREAM_SETMETA, identifier, 'meta', meta) - async def stream_setproperty(self, identifier, stream_property, value): + async def stream_setproperty(self, identifier: str, stream_property: str, value: Any) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Set stream metadata.""" self._version_check(STREAM_SETPROPERTY) return await self._request(STREAM_SETPROPERTY, identifier, parameters={ @@ -248,68 +252,68 @@ async def stream_setproperty(self, identifier, stream_property, value): 'value': value }) - async def stream_add_stream(self, stream_uri): + async def stream_add_stream(self, stream_uri: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Add a stream.""" params = {"streamUri": stream_uri} result, error = await self._transact(STREAM_ADDSTREAM, params) - if (isinstance(result, dict) and ("id" in result)): + if isinstance(result, dict) and ("id" in result): self.synchronize((await self.status())[0]) return result or error - async def stream_remove_stream(self, identifier): - """Remove a Stream.""" + async def stream_remove_stream(self, identifier: str) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: + """Remove a Stream from the server.""" result = await self._request(STREAM_REMOVESTREAM, identifier) - if (isinstance(result, dict) and ("id" in result)): + if isinstance(result, dict) and ("id" in result): self.synchronize((await self.status())[0]) return result - def group(self, group_identifier): + def group(self, group_identifier: str) -> Any: """Get a group.""" return self._groups[group_identifier] - def stream(self, stream_identifier): + def stream(self, stream_identifier: str) -> Any: """Get a stream.""" return self._streams[stream_identifier] - def client(self, client_identifier): + def client(self, client_identifier: str) -> Any: """Get a client.""" return self._clients[client_identifier] @property - def groups(self): + def groups(self) -> List[Any]: """Get groups.""" return list(self._groups.values()) @property - def clients(self): + def clients(self) -> List[Any]: """Get clients.""" return list(self._clients.values()) @property - def streams(self): + def streams(self) -> List[Any]: """Get streams.""" return list(self._streams.values()) - def synchronize(self, status): + def synchronize(self, status: Dict[str, Any]) -> None: """Synchronize snapserver.""" self._version = status['server']['server']['snapserver']['version'] - new_groups = {} - new_clients = {} - new_streams = {} - for stream in status.get('server').get('streams'): + new_groups: Dict[str, Any] = {} + new_clients: Dict[str, Any] = {} + new_streams: Dict[str, Any] = {} + for stream in status.get('server', {}).get('streams', []): if stream.get('id') in self._streams: new_streams[stream.get('id')] = self._streams[stream.get('id')] new_streams[stream.get('id')].update(stream) else: new_streams[stream.get('id')] = Snapstream(stream) _LOGGER.debug('stream found: %s', new_streams[stream.get('id')]) - for group in status.get('server').get('groups'): + for group in status.get('server', {}).get('groups', []): if group.get('id') in self._groups: new_groups[group.get('id')] = self._groups[group.get('id')] new_groups[group.get('id')].update(group) else: new_groups[group.get('id')] = Snapgroup(self, group) - for client in group.get('clients'): + for client in group.get('clients', []): if client.get('id') in self._clients: new_clients[client.get('id')] = self._clients[client.get('id')] new_clients[client.get('id')].update(client) @@ -322,7 +326,7 @@ def synchronize(self, status): self._streams = new_streams # pylint: disable=too-many-arguments - async def _request(self, method, identifier, key=None, value=None, parameters=None): + async def _request(self, method: str, identifier: str, key: Optional[str] = None, value: Optional[Any] = None, parameters: Optional[Dict[str, Any]] = None) -> Tuple[Optional[Any], Optional[Dict[str, Any]]]: """Perform request with identifier.""" params = {'id': identifier} if key is not None and value is not None: @@ -331,16 +335,16 @@ async def _request(self, method, identifier, key=None, value=None, parameters=No params.update(parameters) result, error = await self._transact(method, params) if isinstance(result, dict) and key in result: - return result.get(key) - return result or error + return result.get(key), None + return result, error - def _on_server_connect(self): + def _on_server_connect(self) -> None: """Handle server connection.""" _LOGGER.debug('Server connected') if self._on_connect_callback_func and callable(self._on_connect_callback_func): self._on_connect_callback_func() - def _on_server_disconnect(self, exception): + def _on_server_disconnect(self, exception: Optional[Exception]) -> None: """Handle server disconnection.""" _LOGGER.debug('Server disconnected: %s', str(exception)) if self._on_disconnect_callback_func and callable(self._on_disconnect_callback_func): @@ -350,31 +354,34 @@ def _on_server_disconnect(self, exception): if (not self._is_stopped) and self._reconnect: self._reconnect_cb() - def _on_server_update(self, data): + def _on_server_update(self, data: Dict[str, Any]) -> None: """Handle server update.""" self.synchronize(data) if self._on_update_callback_func and callable(self._on_update_callback_func): self._on_update_callback_func() - def _on_group_mute(self, data): + def _on_group_mute(self, data: Dict[str, Any]) -> None: """Handle group mute.""" group = self._groups.get(data.get('id')) - group.update_mute(data) - for client_id in group.clients: - self._clients.get(client_id).callback() + if group: + group.update_mute(data) + for client_id in group.clients: + self._clients.get(client_id).callback() - def _on_group_name_changed(self, data): + def _on_group_name_changed(self, data: Dict[str, Any]) -> None: """Handle group name changed.""" - self._groups.get(data.get('id')).update_name(data) + if data.get('id') in self._groups: + self._groups[data.get('id')].update_name(data) - def _on_group_stream_changed(self, data): + def _on_group_stream_changed(self, data: Dict[str, Any]) -> None: """Handle group stream change.""" group = self._groups.get(data.get('id')) - group.update_stream(data) - for client_id in group.clients: - self._clients.get(client_id).callback() + if group: + group.update_stream(data) + for client_id in group.clients: + self._clients.get(client_id).callback() - def _on_client_connect(self, data): + def _on_client_connect(self, data: Dict[str, Any]) -> None: """Handle client connect.""" client = None if data.get('id') in self._clients: @@ -387,24 +394,28 @@ def _on_client_connect(self, data): self._new_client_callback_func(client) _LOGGER.debug('client %s connected', client.friendly_name) - def _on_client_disconnect(self, data): + def _on_client_disconnect(self, data: Dict[str, Any]) -> None: """Handle client disconnect.""" - self._clients[data.get('id')].update_connected(False) - _LOGGER.debug('client %s disconnected', self._clients[data.get('id')].friendly_name) + if data.get('id') in self._clients: + self._clients[data.get('id')].update_connected(False) + _LOGGER.debug('client %s disconnected', self._clients[data.get('id')].friendly_name) - def _on_client_volume_changed(self, data): + def _on_client_volume_changed(self, data: Dict[str, Any]) -> None: """Handle client volume change.""" - self._clients.get(data.get('id')).update_volume(data) + if data.get('id') in self._clients: + self._clients.get(data.get('id')).update_volume(data) - def _on_client_name_changed(self, data): + def _on_client_name_changed(self, data: Dict[str, Any]) -> None: """Handle client name changed.""" - self._clients.get(data.get('id')).update_name(data) + if data.get('id') in self._clients: + self._clients.get(data.get('id')).update_name(data) - def _on_client_latency_changed(self, data): + def _on_client_latency_changed(self, data: Dict[str, Any]) -> None: """Handle client latency changed.""" - self._clients.get(data.get('id')).update_latency(data) + if data.get('id') in self._clients: + self._clients.get(data.get('id')).update_latency(data) - def _on_stream_meta(self, data): # deprecated + def _on_stream_meta(self, data: Dict[str, Any]) -> None: # deprecated """Handle stream metadata update.""" if stream := self._streams.get(data.get('id')): stream.update_meta(data.get('meta')) @@ -413,7 +424,7 @@ def _on_stream_meta(self, data): # deprecated if group.stream == data.get('id'): group.callback() - def _on_stream_properties(self, data): + def _on_stream_properties(self, data: Dict[str, Any]) -> None: """Handle stream properties update.""" if stream := self._streams.get(data.get('id')): stream.update_properties(data.get('properties')) @@ -424,7 +435,7 @@ def _on_stream_properties(self, data): for client_id in group.clients: self._clients.get(client_id).callback() - def _on_stream_update(self, data): + def _on_stream_update(self, data: Dict[str, Any]) -> None: """Handle stream update.""" if data.get('id') in self._streams: self._streams[data.get('id')].update(data.get('stream')) @@ -441,31 +452,37 @@ def _on_stream_update(self, data): else: _LOGGER.info('stream %s not found, synchronize', data.get('id')) - async def async_sync(): + async def async_sync() -> None: self.synchronize((await self.status())[0]) asyncio.ensure_future(async_sync()) - def set_on_update_callback(self, func): + def set_on_update_callback(self, func: Callable[[], None]) -> None: """Set on update callback function.""" self._on_update_callback_func = func - def set_on_connect_callback(self, func): + def set_on_connect_callback(self, func: Callable[[], None]) -> None: """Set on connection callback function.""" self._on_connect_callback_func = func - def set_on_disconnect_callback(self, func): + def set_on_disconnect_callback(self, func: Callable[[Optional[Exception]], None]) -> None: """Set on disconnection callback function.""" self._on_disconnect_callback_func = func - def set_new_client_callback(self, func): + def set_new_client_callback(self, func: Callable[[Any], None]) -> None: """Set new client callback function.""" self._new_client_callback_func = func - def __repr__(self): - """Return string representation.""" + def __repr__(self) -> str: + """Return string representation of the server.""" return f'Snapserver {self.version} ({self._host})' - def _version_check(self, api_call): + def _version_check(self, api_call: str) -> None: + """ + Checks if the server version meets the minimum requirement for a given API call. + + Raises: + ServerVersionError: If the server version is lower than the required version for the API call. + """ if version.parse(self.version) < version.parse(_VERSIONS.get(api_call)): raise ServerVersionError( f"{api_call} requires server version >= {_VERSIONS[api_call]}." diff --git a/snapcast/control/stream.py b/snapcast/control/stream.py index d9a663e..6cac892 100644 --- a/snapcast/control/stream.py +++ b/snapcast/control/stream.py @@ -1,83 +1,84 @@ """Snapcast stream.""" +from typing import Any, Callable, Optional -class Snapstream(): +class Snapstream: """Represents a snapcast stream.""" - def __init__(self, data): - """Initialize.""" + def __init__(self, data: dict) -> None: + """Initialize the Stream object.""" self.update(data) - self._callback_func = None + self._callback_func: Optional[Callable[['Snapstream'], None]] = None @property - def identifier(self): + def identifier(self) -> str: """Get stream id.""" return self._stream.get('id') @property - def status(self): + def status(self) -> Any: """Get stream status.""" return self._stream.get('status') @property - def name(self): + def name(self) -> str: """Get stream name.""" - return self._stream.get('uri').get('query').get('name') + return self._stream.get('uri', {}).get('query', {}).get('name', '') @property - def friendly_name(self): + def friendly_name(self) -> str: """Get friendly name.""" return self.name if self.name != '' else self.identifier @property - def metadata(self): + def metadata(self) -> Optional[dict]: """Get metadata.""" if 'properties' in self._stream: return self._stream['properties'].get('metadata') return self._stream.get('meta') @property - def meta(self): + def meta(self) -> Optional[dict]: """Get metadata. Deprecated.""" return self.metadata @property - def properties(self): + def properties(self) -> Optional[dict]: """Get properties.""" return self._stream.get('properties') @property - def path(self): + def path(self) -> str: """Get stream path.""" - return self._stream.get('uri').get('path') + return self._stream.get('uri', {}).get('path', '') - def update(self, data): + def update(self, data: dict) -> None: """Update stream.""" self._stream = data - def update_meta(self, data): + def update_meta(self, data: dict) -> None: """Update stream metadata.""" self.update_metadata(data) - def update_metadata(self, data): + def update_metadata(self, data: dict) -> None: """Update stream metadata.""" if 'properties' in self._stream: self._stream['properties']['metadata'] = data self._stream['meta'] = data - def update_properties(self, data): + def update_properties(self, data: dict) -> None: """Update stream properties.""" self._stream['properties'] = data - def __repr__(self): + def __repr__(self) -> str: """Return string representation.""" return f'Snapstream ({self.name})' - def callback(self): - """Run callback.""" + def callback(self) -> None: + """Run callback if set.""" if self._callback_func and callable(self._callback_func): self._callback_func(self) - def set_callback(self, func): - """Set callback.""" + def set_callback(self, func: Callable[['Snapstream'], None]) -> None: + """Set callback function.""" self._callback_func = func