diff --git a/README.md b/README.md index ace817d..dca9863 100644 --- a/README.md +++ b/README.md @@ -3,8 +3,9 @@ # python-snapcast Control [Snapcast](https://github.com/badaix/snapcast) in Python 3. Reads client configurations, updates clients, and receives updates from other controllers. +The connection could be made with the json-rpc or Websockets interface. Websockets is more stable due to [issue](https://github.com/badaix/snapcast/issues/1173) in snapserver. -Supports Snapcast `0.15.0`. +Supports Snapcast `0.15.0`, but works well with latest Snapcast `0.27.0` ## Install @@ -18,7 +19,7 @@ import asyncio import snapcast.control loop = asyncio.get_event_loop() -server = loop.run_until_complete(snapcast.control.create_server(loop, 'localhost')) +server = loop.run_until_complete(snapcast.control.create_server(loop, 'localhost', port=1780, reconnect=True, use_websockets=True)) # print all client names for client in server.clients: diff --git a/setup.cfg b/setup.cfg index b88034e..08aedd7 100644 --- a/setup.cfg +++ b/setup.cfg @@ -1,2 +1,2 @@ [metadata] -description-file = README.md +description_file = README.md diff --git a/setup.py b/setup.py index 6c934d0..1fc14f5 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name='snapcast', - version='2.3.3', + version='2.3.6', description='Control Snapcast.', url='https://github.com/happyleavesaoc/python-snapcast/', license='MIT', @@ -12,6 +12,7 @@ install_requires=[ 'construct>=2.5.2', 'packaging', + 'websockets', ], classifiers=[ 'License :: OSI Approved :: MIT License', diff --git a/snapcast/control/__init__.py b/snapcast/control/__init__.py index ec6ad3e..927262d 100644 --- a/snapcast/control/__init__.py +++ b/snapcast/control/__init__.py @@ -3,8 +3,8 @@ from snapcast.control.server import Snapserver, CONTROL_PORT -async def create_server(loop, host, port=CONTROL_PORT, reconnect=False): +async def create_server(loop, host, port=CONTROL_PORT, reconnect=False, use_websockets=False): """Server factory.""" - server = Snapserver(loop, host, port, reconnect) + server = Snapserver(loop, host, port, reconnect, use_websockets) await server.start() return server diff --git a/snapcast/control/client.py b/snapcast/control/client.py index 9b850e6..af14496 100644 --- a/snapcast/control/client.py +++ b/snapcast/control/client.py @@ -32,6 +32,7 @@ def group(self): for group in self._server.groups: if self.identifier in group.clients: return group + return None @property def friendly_name(self): @@ -162,5 +163,5 @@ def set_callback(self, func): self._callback_func = func def __repr__(self): - """String representation.""" + """Return string representation.""" return f'Snapclient {self.version} ({self.friendly_name}, {self.identifier})' diff --git a/snapcast/control/group.py b/snapcast/control/group.py index 7935b2f..98bbad1 100644 --- a/snapcast/control/group.py +++ b/snapcast/control/group.py @@ -105,8 +105,10 @@ async def set_volume(self, volume): @property def friendly_name(self): """Get friendly name.""" - return self.name if self.name != '' else "+".join( - sorted([self._server.client(c).friendly_name for c in self.clients])) + fname = self.name if self.name != '' else "+".join( + sorted([self._server.client(c).friendly_name for c in self.clients + if c in [client.identifier for client in self._server.clients]])) + return fname if fname != '' else self.identifier @property def clients(self): @@ -122,7 +124,7 @@ async def add_client(self, client_identifier): new_clients.append(client_identifier) await self._server.group_clients(self.identifier, new_clients) _LOGGER.debug('added %s to %s', client_identifier, self.identifier) - status = await self._server.status() + status = (await self._server.status())[0] self._server.synchronize(status) self._server.client(client_identifier).callback() self.callback() @@ -133,7 +135,7 @@ async def remove_client(self, client_identifier): new_clients.remove(client_identifier) await self._server.group_clients(self.identifier, new_clients) _LOGGER.debug('removed %s from %s', client_identifier, self.identifier) - status = await self._server.status() + status = (await self._server.status())[0] self._server.synchronize(status) self._server.client(client_identifier).callback() self.callback() @@ -189,5 +191,5 @@ def set_callback(self, func): self._callback_func = func def __repr__(self): - """String representation.""" + """Return string representation.""" return f'Snapgroup ({self.friendly_name}, {self.identifier})' diff --git a/snapcast/control/protocol.py b/snapcast/control/protocol.py index d6df646..9a21050 100644 --- a/snapcast/control/protocol.py +++ b/snapcast/control/protocol.py @@ -7,6 +7,7 @@ SERVER_ONDISCONNECT = 'Server.OnDisconnect' +# pylint: disable=consider-using-f-string def jsonrpc_request(method, identifier, params=None): """Produce a JSONRPC request.""" return '{}\r\n'.format(json.dumps({ @@ -33,6 +34,9 @@ def connection_made(self, transport): def connection_lost(self, exc): """When a connection is lost.""" + for b in self._buffer.values(): + b['error'] = {"code": -1, "message": "connection lost"} + b['flag'].set() self._callbacks.get(SERVER_ONDISCONNECT)(exc) def data_received(self, data): @@ -74,8 +78,8 @@ async def request(self, method, params): self._transport.write(jsonrpc_request(method, identifier, params)) self._buffer[identifier] = {'flag': asyncio.Event()} await self._buffer[identifier]['flag'].wait() - result = self._buffer[identifier]['data'] - error = self._buffer[identifier]['error'] - del self._buffer[identifier]['data'] - del self._buffer[identifier]['error'] + result = self._buffer[identifier].get('data') + error = self._buffer[identifier].get('error') + self._buffer[identifier].clear() + del self._buffer[identifier] return (result, error) diff --git a/snapcast/control/server.py b/snapcast/control/server.py index e93f5b1..af7a8e6 100644 --- a/snapcast/control/server.py +++ b/snapcast/control/server.py @@ -2,16 +2,19 @@ import asyncio import logging +import websockets from packaging import version from snapcast.control.client import Snapclient from snapcast.control.group import Snapgroup from snapcast.control.protocol import SERVER_ONDISCONNECT, SnapcastProtocol +from snapcast.control.wsprotocol import SnapcastWebSocketProtocol from snapcast.control.stream import Snapstream _LOGGER = logging.getLogger(__name__) CONTROL_PORT = 1705 +WEBSOCKET_PORT = 1780 SERVER_GETSTATUS = 'Server.GetStatus' SERVER_GETRPCVERSION = 'Server.GetRPCVersion' @@ -44,6 +47,8 @@ STREAM_SETMETA = 'Stream.SetMeta' # deprecated STREAM_ONUPDATE = 'Stream.OnUpdate' STREAM_ONMETA = 'Stream.OnMetadata' # deprecated +STREAM_ADDSTREAM = 'Stream.AddStream' +STREAM_REMOVESTREAM = 'Stream.RemoveStream' SERVER_RECONNECT_DELAY = 5 @@ -55,16 +60,21 @@ SERVER_DELETECLIENT, CLIENT_GETSTATUS, CLIENT_SETNAME, CLIENT_SETLATENCY, CLIENT_SETVOLUME, GROUP_GETSTATUS, GROUP_SETMUTE, GROUP_SETSTREAM, GROUP_SETCLIENTS, - GROUP_SETNAME, STREAM_SETMETA, STREAM_SETPROPERTY, STREAM_CONTROL] + GROUP_SETNAME, STREAM_SETMETA, STREAM_SETPROPERTY, STREAM_CONTROL, + STREAM_ADDSTREAM, STREAM_REMOVESTREAM] # server versions in which new methods were added _VERSIONS = { GROUP_SETNAME: '0.16.0', STREAM_SETPROPERTY: '0.26.0', + STREAM_ADDSTREAM: '0.16.0', + STREAM_REMOVESTREAM: '0.16.0', } class ServerVersionError(NotImplementedError): + """Server Version Error, not implemented.""" + pass @@ -72,10 +82,11 @@ class ServerVersionError(NotImplementedError): class Snapserver(): """Represents a snapserver.""" - # pylint: disable=too-many-instance-attributes - def __init__(self, loop, host, port=CONTROL_PORT, reconnect=False): + # pylint: disable=too-many-instance-attributes,too-many-arguments + def __init__(self, loop, host, port=CONTROL_PORT, reconnect=False, use_websockets=False): """Initialize.""" self._loop = loop + self._use_websockets = use_websockets self._port = port self._reconnect = reconnect self._is_stopped = True @@ -86,6 +97,7 @@ def __init__(self, loop, host, port=CONTROL_PORT, reconnect=False): self._version = None self._protocol = None self._transport = None + self._websocket = None self._callbacks = { CLIENT_ONCONNECT: self._on_client_connect, CLIENT_ONDISCONNECT: self._on_client_disconnect, @@ -110,46 +122,76 @@ async def start(self): """Initiate server connection.""" self._is_stopped = False await self._do_connect() + status, error = await self.status() + if (not isinstance(status, dict)) or ('server' not in status): + _LOGGER.warning('connected, but no valid response:\n%s', str(error)) + self.stop() + raise OSError _LOGGER.debug('connected to snapserver on %s:%s', self._host, self._port) - status = await self.status() self.synchronize(status) self._on_server_connect() - async def stop(self): + def stop(self): """Stop server.""" self._is_stopped = True self._do_disconnect() - _LOGGER.debug('disconnected from snapserver on %s:%s', self._host, self._port) + _LOGGER.debug('Stopping') self._clients = {} self._streams = {} self._groups = {} self._version = None - self._protocol = None - self._transport = None def _do_disconnect(self): - """Perform the connection to the server.""" + """Disconnect from server.""" if self._transport: self._transport.close() async def _do_connect(self): """Perform the connection to the server.""" - self._transport, self._protocol = await self._loop.create_connection( - lambda: SnapcastProtocol(self._callbacks), self._host, self._port) + connected = asyncio.Event() + + # actual corutine to handle websocket connection + async def websocket_handler(): + _LOGGER.debug('try connect to websocket') + async for self._websocket in websockets.connect( + uri=f"ws://{self._host}:{self._port}/jsonrpc"): + self._protocol = SnapcastWebSocketProtocol(self._websocket, self._callbacks) + connected.set() + try: + # Receives the replies. + async for message in self._websocket: + self._protocol.message_received(message) + except websockets.ConnectionClosed: + if self._reconnect and not self._is_stopped: + _LOGGER.debug('try reconnect to websocket') + continue + # Closes the connection. + await self._websocket.close() + + if self._use_websockets: + self._loop.create_task(websocket_handler()) + await connected.wait() + else: + self._transport, self._protocol = await self._loop.create_connection( + lambda: SnapcastProtocol(self._callbacks), self._host, self._port) def _reconnect_cb(self): - """Callback to reconnect to the server.""" + """Try to reconnect to the server.""" _LOGGER.debug('try reconnect') async def try_reconnect(): """Actual coroutine ro try to reconnect or reschedule.""" try: await self._do_connect() + status, error = await self.status() + if (not isinstance(status, dict)) or ('server' not in status): + _LOGGER.warning('connected, but no valid response:\n%s', str(error)) + self.stop() + raise OSError except OSError: self._loop.call_later(SERVER_RECONNECT_DELAY, self._reconnect_cb) else: - status = await self.status() self.synchronize(status) self._on_server_connect() asyncio.ensure_future(try_reconnect()) @@ -157,12 +199,11 @@ async def try_reconnect(): async def _transact(self, method, params=None): """Wrap requests.""" result = error = None - try: + if self._protocol is None or self._transport is None or self._transport.is_closing(): + error = {"code": None, "message": "Server not connected"} + else: result, error = await self._protocol.request(method, params) - except: - _LOGGER.warning('could not send request') - error = 'could not send request' - return result or error + return (result, error) @property def version(self): @@ -171,73 +212,88 @@ def version(self): async def status(self): """System status.""" - result = await self._transact(SERVER_GETSTATUS) - return result + return await self._transact(SERVER_GETSTATUS) - def rpc_version(self): + async def rpc_version(self): """RPC version.""" - return self._transact(SERVER_GETRPCVERSION) + return await self._transact(SERVER_GETRPCVERSION) async def delete_client(self, identifier): """Delete client.""" params = {'id': identifier} - response = await self._transact(SERVER_DELETECLIENT, params) + response, _ = await self._transact(SERVER_DELETECLIENT, params) self.synchronize(response) - def client_name(self, identifier, name): + async def client_name(self, identifier, name): """Set client name.""" - return self._request(CLIENT_SETNAME, identifier, 'name', name) + return await self._request(CLIENT_SETNAME, identifier, 'name', name) - def client_latency(self, identifier, latency): + async def client_latency(self, identifier, latency): """Set client latency.""" - return self._request(CLIENT_SETLATENCY, identifier, 'latency', latency) + return await self._request(CLIENT_SETLATENCY, identifier, 'latency', latency) - def client_volume(self, identifier, volume): + async def client_volume(self, identifier, volume): """Set client volume.""" - return self._request(CLIENT_SETVOLUME, identifier, 'volume', volume) + return await self._request(CLIENT_SETVOLUME, identifier, 'volume', volume) - def client_status(self, identifier): + async def client_status(self, identifier): """Get client status.""" - return self._request(CLIENT_GETSTATUS, identifier, 'client') + return await self._request(CLIENT_GETSTATUS, identifier, 'client') - def group_status(self, identifier): + async def group_status(self, identifier): """Get group status.""" - return self._request(GROUP_GETSTATUS, identifier, 'group') + return await self._request(GROUP_GETSTATUS, identifier, 'group') - def group_mute(self, identifier, status): + async def group_mute(self, identifier, status): """Set group mute.""" - return self._request(GROUP_SETMUTE, identifier, 'mute', status) + return await self._request(GROUP_SETMUTE, identifier, 'mute', status) - def group_stream(self, identifier, stream_id): + async def group_stream(self, identifier, stream_id): """Set group stream.""" - return self._request(GROUP_SETSTREAM, identifier, 'stream_id', stream_id) + return await self._request(GROUP_SETSTREAM, identifier, 'stream_id', stream_id) - def group_clients(self, identifier, clients): + async def group_clients(self, identifier, clients): """Set group clients.""" - return self._request(GROUP_SETCLIENTS, identifier, 'clients', clients) + return await self._request(GROUP_SETCLIENTS, identifier, 'clients', clients) - def group_name(self, identifier, name): + async def group_name(self, identifier, name): """Set group name.""" self._version_check(GROUP_SETNAME) - return self._request(GROUP_SETNAME, identifier, 'name', name) + return await self._request(GROUP_SETNAME, identifier, 'name', name) - def stream_control(self, identifier, control_command, control_params): + async def stream_control(self, identifier, control_command, control_params): """Set stream control.""" self._version_check(STREAM_SETPROPERTY) - return self._request(STREAM_CONTROL, identifier, 'command', control_command, control_params) + return await self._request( + STREAM_CONTROL, identifier, 'command', control_command, control_params) - def stream_setmeta(self, identifier, meta): # deprecated + async def stream_setmeta(self, identifier, meta): # deprecated """Set stream metadata.""" - return self._request(STREAM_SETMETA, identifier, 'meta', meta) + return await self._request(STREAM_SETMETA, identifier, 'meta', meta) - def stream_setproperty(self, identifier, stream_property, value): + async def stream_setproperty(self, identifier, stream_property, value): """Set stream metadata.""" self._version_check(STREAM_SETPROPERTY) - return self._request(STREAM_SETPROPERTY, identifier, parameters={ + return await self._request(STREAM_SETPROPERTY, identifier, parameters={ 'property': stream_property, 'value': value }) + async def stream_add_stream(self, stream_uri): + """Add a stream.""" + params = {"streamUri": stream_uri} + result, error = await self._transact(STREAM_ADDSTREAM, params) + if (isinstance(result, dict) and ("id" in result)): + self.synchronize((await self.status())[0]) + return result or error + + async def stream_remove_stream(self, identifier): + """Remove a Stream.""" + result = await self._request(STREAM_REMOVESTREAM, identifier) + if (isinstance(result, dict) and ("id" in result)): + self.synchronize((await self.status())[0]) + return result + def group(self, group_identifier): """Get a group.""" return self._groups[group_identifier] @@ -284,7 +340,6 @@ def synchronize(self, status): new_groups[group.get('id')].update(group) else: new_groups[group.get('id')] = Snapgroup(self, group) - _LOGGER.debug('group found: %s', new_groups[group.get('id')]) for client in group.get('clients'): if client.get('id') in self._clients: new_clients[client.get('id')] = self._clients[client.get('id')] @@ -292,10 +347,12 @@ def synchronize(self, status): else: new_clients[client.get('id')] = Snapclient(self, client) _LOGGER.debug('client found: %s', new_clients[client.get('id')]) + _LOGGER.debug('group found: %s', new_groups[group.get('id')]) self._groups = new_groups self._clients = new_clients self._streams = new_streams + # pylint: disable=too-many-arguments async def _request(self, method, identifier, key=None, value=None, parameters=None): """Perform request with identifier.""" params = {'id': identifier} @@ -303,10 +360,10 @@ async def _request(self, method, identifier, key=None, value=None, parameters=No params[key] = value if isinstance(parameters, dict): params.update(parameters) - result = await self._transact(method, params) + result, error = await self._transact(method, params) if isinstance(result, dict) and key in result: return result.get(key) - return result + return result or error def _on_server_connect(self): """Handle server connection.""" @@ -316,15 +373,13 @@ def _on_server_connect(self): def _on_server_disconnect(self, exception): """Handle server disconnection.""" - _LOGGER.debug('Server disconnected') + _LOGGER.debug('Server disconnected: %s', str(exception)) if self._on_disconnect_callback_func and callable(self._on_disconnect_callback_func): self._on_disconnect_callback_func(exception) - if not self._is_stopped: - self._do_disconnect() - self._protocol = None - self._transport = None - if self._reconnect: - self._reconnect_cb() + self._protocol = None + self._transport = None + if (not self._is_stopped) and self._reconnect: + self._reconnect_cb() def _on_server_update(self, data): """Handle server update.""" @@ -336,8 +391,8 @@ def _on_group_mute(self, data): """Handle group mute.""" group = self._groups.get(data.get('id')) group.update_mute(data) - for clientID in group.clients: - self._clients.get(clientID).callback() + for client_id in group.clients: + self._clients.get(client_id).callback() def _on_group_name_changed(self, data): """Handle group name changed.""" @@ -347,8 +402,8 @@ def _on_group_stream_changed(self, data): """Handle group stream change.""" group = self._groups.get(data.get('id')) group.update_stream(data) - for clientID in group.clients: - self._clients.get(clientID).callback() + for client_id in group.clients: + self._clients.get(client_id).callback() def _on_client_connect(self, data): """Handle client connect.""" @@ -397,8 +452,8 @@ def _on_stream_properties(self, data): for group in self._groups.values(): if group.stream == data.get('id'): group.callback() - for clientID in group.clients: - self._clients.get(clientID).callback() + for client_id in group.clients: + self._clients.get(client_id).callback() def _on_stream_update(self, data): """Handle stream update.""" @@ -408,8 +463,8 @@ def _on_stream_update(self, data): for group in self._groups.values(): if group.stream == data.get('id'): group.callback() - for clientID in group.clients: - self._clients.get(clientID).callback() + for client_id in group.clients: + self._clients.get(client_id).callback() def set_on_update_callback(self, func): """Set on update callback function.""" @@ -428,7 +483,7 @@ def set_new_client_callback(self, func): self._new_client_callback_func = func def __repr__(self): - """String representation.""" + """Return string representation.""" return f'Snapserver {self.version} ({self._host})' def _version_check(self, api_call): diff --git a/snapcast/control/stream.py b/snapcast/control/stream.py index a298671..d9a663e 100644 --- a/snapcast/control/stream.py +++ b/snapcast/control/stream.py @@ -46,6 +46,11 @@ def properties(self): """Get properties.""" return self._stream.get('properties') + @property + def path(self): + """Get stream path.""" + return self._stream.get('uri').get('path') + def update(self, data): """Update stream.""" self._stream = data @@ -65,7 +70,7 @@ def update_properties(self, data): self._stream['properties'] = data def __repr__(self): - """String representation.""" + """Return string representation.""" return f'Snapstream ({self.name})' def callback(self): diff --git a/snapcast/control/wsprotocol.py b/snapcast/control/wsprotocol.py new file mode 100644 index 0000000..2204735 --- /dev/null +++ b/snapcast/control/wsprotocol.py @@ -0,0 +1,66 @@ +"""Snapcast protocol.""" + +import asyncio +import json +import random + +SERVER_ONDISCONNECT = 'Server.OnDisconnect' + + +def jsonrpc_request(method, identifier, params=None): + """Produce a JSONRPC request.""" + return '{}\r\n'.format(json.dumps({ + 'id': identifier, + 'method': method, + 'params': params or {}, + 'jsonrpc': '2.0' + })).encode() + + +class SnapcastWebSocketProtocol(): + """Async Snapcast protocol.""" + + def __init__(self, websocket, callbacks): + """Initialize.""" + self._websocket = websocket + self._callbacks = callbacks + self._buffer = {} + + def message_received(self, message): + """Handle received data.""" + data = json.loads(message) + if not isinstance(data, list): + data = [data] + for item in data: + self.handle_data(item) + + def handle_data(self, data): + """Handle JSONRPC data.""" + if 'id' in data: + self.handle_response(data) + else: + self.handle_notification(data) + + def handle_response(self, data): + """Handle JSONRPC response.""" + identifier = data.get('id') + self._buffer[identifier]['data'] = data.get('result') + self._buffer[identifier]['error'] = data.get('error') + self._buffer[identifier]['flag'].set() + + def handle_notification(self, data): + """Handle JSONRPC notification.""" + if data.get('method') in self._callbacks: + self._callbacks.get(data.get('method'))(data.get('params')) + + async def request(self, method, params): + """Send a JSONRPC request.""" + identifier = random.randint(1, 1000) + await self._websocket.send(jsonrpc_request(method, identifier, params)) + self._buffer[identifier] = {'flag': asyncio.Event()} + await self._buffer[identifier]['flag'].wait() + result = self._buffer[identifier]['data'] + error = self._buffer[identifier]['error'] + del self._buffer[identifier]['data'] + del self._buffer[identifier]['error'] + return (result, error) diff --git a/tests/test_group.py b/tests/test_group.py index 5c3563b..bd99ec2 100644 --- a/tests/test_group.py +++ b/tests/test_group.py @@ -28,21 +28,26 @@ def setUp(self): client.callback = MagicMock() client.update_volume = MagicMock() client.friendly_name = 'A' + client.identifier = 'a' server.streams = [stream] server.stream = MagicMock(return_value=stream) server.client = MagicMock(return_value=client) + server.clients = [client] self.group = Snapgroup(server, data) def test_init(self): self.assertEqual(self.group.identifier, 'test') self.assertEqual(self.group.name, '') - self.assertEqual(self.group.friendly_name, 'A+A') + self.assertEqual(self.group.friendly_name, 'A') self.assertEqual(self.group.stream, 'test stream') self.assertEqual(self.group.muted, False) self.assertEqual(self.group.volume, 50) self.assertEqual(self.group.clients, ['a', 'b']) self.assertEqual(self.group.stream_status, 'playing') + def test_repr(self): + self.assertEqual(self.group.__repr__(), 'Snapgroup (A, test)') + def test_update(self): self.group.update({ 'stream_id': 'other stream' diff --git a/tests/test_server.py b/tests/test_server.py index 5318efc..01ac701 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -101,11 +101,11 @@ { 'clients': [] } - ], - 'server': SERVER_STATUS, # DeleteClient calls synchronize - 'streams': [ - ] - } + ], + 'server': SERVER_STATUS, # DeleteClient calls synchronize + 'streams': [ + ] + } }, 'Group.GetStatus': { 'group': { @@ -124,12 +124,18 @@ 'Stream.SetMeta': { 'foo': 'bar' }, - 'Stream.SetProperty': 'ok' + 'Stream.SetProperty': 'ok', + 'Stream.AddStream': { + 'id': 'stream 2' + }, + 'Stream.RemoveStream': { + 'id': 'stream 2' + }, } def mock_transact(key): - return AsyncMock(return_value=return_values[key]) + return AsyncMock(return_value=(return_values[key], None)) class TestSnapserver(unittest.TestCase): @@ -143,6 +149,21 @@ def setUp(self): self.server = self._run(create_server(self.loop, 'abcd')) self.server.synchronize(return_values.get('Server.GetStatus')) + @mock.patch.object(Snapserver, 'status', new=AsyncMock( + return_value=(None, {"code": -1, "message": "failed"}))) + @mock.patch.object(Snapserver, '_do_connect', new=AsyncMock()) + @mock.patch.object(Snapserver, 'stop', new=mock.MagicMock()) + def test_start_fail(self): + with self.assertRaises(OSError): + self._run(self.server.start()) + + @mock.patch.object(Snapserver, '_transact', new=mock_transact('Server.GetStatus')) + @mock.patch.object(Snapserver, '_do_connect', new=AsyncMock()) + def test_start(self): + self.server._version = None + self._run(self.server.start()) + self.assertEqual(self.server.version, '0.26.0') + def test_init(self): self.assertEqual(self.server.version, '0.26.0') self.assertEqual(len(self.server.clients), 1) @@ -154,12 +175,12 @@ def test_init(self): @mock.patch.object(Snapserver, '_transact', new=mock_transact('Server.GetStatus')) def test_status(self): - status = self._run(self.server.status()) + status, _ = self._run(self.server.status()) self.assertEqual(status['server']['server']['snapserver']['version'], '0.26.0') @mock.patch.object(Snapserver, '_transact', new=mock_transact('Server.GetRPCVersion')) def test_rpc_version(self): - version = self._run(self.server.rpc_version()) + version, _ = self._run(self.server.rpc_version()) self.assertEqual(version, {'major': 2, 'minor': 0, 'patch': 1}) @mock.patch.object(Snapserver, '_transact', new=mock_transact('Client.SetName')) @@ -213,6 +234,18 @@ def test_stream_setproperty(self): result = self._run(self.server.stream_setproperty('stream', 'foo', 'bar')) self.assertEqual(result, 'ok') + @mock.patch.object(Snapserver, '_transact', new=mock_transact('Stream.AddStream')) + @mock.patch.object(Snapserver, 'synchronize', new=MagicMock()) + def test_stream_addstream(self): + result = self._run(self.server.stream_add_stream('pipe:///tmp/test?name=stream 2')) + self.assertDictEqual(result, {'id': 'stream 2'}) + + @mock.patch.object(Snapserver, '_transact', new=mock_transact('Stream.RemoveStream')) + @mock.patch.object(Snapserver, 'synchronize', new=MagicMock()) + def test_stream_removestream(self): + result = self._run(self.server.stream_remove_stream('stream 2')) + self.assertDictEqual(result, {'id': 'stream 2'}) + def test_synchronize(self): status = copy.deepcopy(return_values.get('Server.GetStatus')) status['server']['server']['snapserver']['version'] = '0.12' diff --git a/tests/test_stream.py b/tests/test_stream.py index 8e20c91..d4f3623 100644 --- a/tests/test_stream.py +++ b/tests/test_stream.py @@ -21,6 +21,7 @@ def setUp(self): 'id': 'test', 'status': 'playing', 'uri': { + 'path': '/tmp/snapfifo', 'query': { 'name': '' } @@ -40,9 +41,11 @@ def test_init(self): self.assertEqual(self.stream.status, 'playing') self.assertEqual(self.stream.name, '') self.assertEqual(self.stream.friendly_name, 'test') + self.assertEqual(self.stream.path, '/tmp/snapfifo') self.assertDictEqual(self.stream_meta.meta, {'TITLE': 'Happy!'}) self.assertDictEqual(self.stream.properties['metadata'], {'title': 'Happy!'}) - self.assertDictEqual(self.stream.properties, {'canControl': False, 'metadata': {'title': 'Happy!',}}) + self.assertDictEqual(self.stream.properties, + {'canControl': False, 'metadata': {'title': 'Happy!'}}) self.assertDictEqual(self.stream.metadata, {'title': 'Happy!'}) def test_update(self): diff --git a/tests/test_wsprotocol.py b/tests/test_wsprotocol.py new file mode 100644 index 0000000..2205ce6 --- /dev/null +++ b/tests/test_wsprotocol.py @@ -0,0 +1,51 @@ +import unittest +import asyncio +from unittest.mock import MagicMock, patch + +from snapcast.control.protocol import jsonrpc_request +from snapcast.control.wsprotocol import SnapcastWebSocketProtocol + +class TestSnapcastProtocol(unittest.TestCase): + def setUp(self): + self.websocket = MagicMock() + self.callbacks = { + 'Server.OnDisconnect': MagicMock() + } + self.protocol = SnapcastWebSocketProtocol(self.websocket, self.callbacks) + + def test_jsonrpc_request(self): + method = 'Server.GetStatus' + identifier = 123 + params = {'param1': 'value1'} + expected_request = '{"id": 123, "method": "Server.GetStatus", "params": {"param1": "value1"}, "jsonrpc": "2.0"}\r\n'.encode() + request = jsonrpc_request(method, identifier, params) + self.assertEqual(request, expected_request) + + def test_handle_response(self): + response_data = { + 'id': 123, + 'result': {'status': 'ok'}, + 'error': None + } + self.protocol.handle_data(response_data) + self.assertTrue(self.protocol._buffer[123]['flag'].is_set()) + self.assertEqual(self.protocol._buffer[123]['data'], {'status': 'ok'}) + self.assertIsNone(self.protocol._buffer[123]['error']) + + def test_handle_notification(self): + notification_data = { + 'method': 'Server.OnDisconnect', + 'params': {'client': 'client1'} + } + self.protocol.handle_data(notification_data) + self.callbacks['Server.OnDisconnect'].assert_called_with({'client': 'client1'}) + + @patch('snapcast.control.protocol.jsonrpc_request') + def test_request(self, mock_jsonrpc_request): + mock_jsonrpc_request.return_value = b'{"id": 123, "method": "Server.GetStatus", "params": {}, "jsonrpc": "2.0"}\r\n' + self.protocol._buffer[123] = {'flag': asyncio.Event(), 'data': {'status': 'ok'}, 'error': None} + + loop = asyncio.new_event_loop() + result, error = loop.run_until_complete(self.protocol.request('Server.GetStatus', {})) + self.assertEqual(result, {'status': 'ok'}) + self.assertIsNone(error)