Skip to content

Commit

Permalink
Support new DTLS Identity method (#92)
Browse files Browse the repository at this point in the history
  • Loading branch information
lwis authored Nov 1, 2017
1 parent bcafe2f commit ed1fea1
Show file tree
Hide file tree
Showing 10 changed files with 203 additions and 83 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,5 @@ dist/
# Unit test / coverage reports
.coverage
.tox

gateway_psk.txt
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
[![Coverage Status](https://coveralls.io/repos/github/ggravlingen/pytradfri/badge.svg?branch=master)](https://coveralls.io/github/ggravlingen/pytradfri?branch=master)
[![PyPI version](https://badge.fury.io/py/pytradfri.svg)](https://badge.fury.io/py/pytradfri)

**NB:** Latest Gateway version tested and working - 1.2.42.

Python class to communicate with the [IKEA Trådfri](http://www.ikea.com/us/en/catalog/products/00337813/) (Tradfri) ZigBee-based Gateway. Using this library you can, by communicating with the gateway, control IKEA lights (including the RGB ones) and also Philips Hue bulbs. Some of the features include:

- Get information on the gateway
Expand Down
21 changes: 16 additions & 5 deletions example_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import sys

from pytradfri import Gateway
from pytradfri.api.aiocoap_api import api_factory
from pytradfri.api.aiocoap_api import APIFactory

root = logging.getLogger()
root.setLevel(logging.INFO)
Expand All @@ -35,7 +35,17 @@
def run():
# Assign configuration variables.
# The configuration check takes care they are present.
api = yield from api_factory(sys.argv[1], sys.argv[2])
api_factory = APIFactory(sys.argv[1])
with open('gateway_psk.txt', 'a+') as file:
file.seek(0)
psk = file.read()
if psk:
api_factory.psk = psk.strip()
else:
psk = yield from api_factory.generate_psk(sys.argv[2])
print('Generated PSK: ', psk)
file.write(psk)
api = api_factory.request

gateway = Gateway()

Expand All @@ -45,9 +55,6 @@ def run():

lights = [dev for dev in devices if dev.has_light_control]

tasks_command = gateway.get_smart_tasks()
tasks = yield from api(tasks_command)

# Print all lights
print(lights)

Expand Down Expand Up @@ -87,6 +94,10 @@ def observe_err_callback(err):
color_command = light.light_control.set_hex_color('efd275')
yield from api(color_command)

tasks_command = gateway.get_smart_tasks()
tasks_commands = yield from api(tasks_command)
tasks = yield from api(tasks_commands)

# Example 6: Return the transition time (in minutes) for task#1
if tasks:
print(tasks[0].task_control.tasks[0].transition_time)
Expand Down
21 changes: 16 additions & 5 deletions example_sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import time

from pytradfri import Gateway
from pytradfri.api.libcoap_api import api_factory
from pytradfri.api.libcoap_api import APIFactory


def observe(api, device):
Expand All @@ -39,7 +39,17 @@ def worker():
def run():
# Assign configuration variables.
# The configuration check takes care they are present.
api = api_factory(sys.argv[1], sys.argv[2])
api_factory = APIFactory(sys.argv[1])
with open('gateway_psk.txt', 'a+') as file:
file.seek(0)
psk = file.read()
if psk:
api_factory.psk = psk.strip()
else:
psk = api_factory.generate_psk(sys.argv[2])
print('Generated PSK: ', psk)
file.write(psk)
api = api_factory.request

gateway = Gateway()

Expand All @@ -49,9 +59,6 @@ def run():

lights = [dev for dev in devices if dev.has_light_control]

tasks_command = gateway.get_smart_tasks()
tasks = api(tasks_command)

# Print all lights
print(lights)

Expand All @@ -78,6 +85,10 @@ def run():
color_command = light.light_control.set_hex_color('efd275')
api(color_command)

tasks_command = gateway.get_smart_tasks()
tasks_commands = api(tasks_command)
tasks = api(tasks_commands)

# Example 6: Return the transition time (in minutes) for task#1
if tasks:
print(tasks[0].task_control.tasks[0].transition_time)
Expand Down
140 changes: 94 additions & 46 deletions pytradfri/api/aiocoap_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,130 +9,162 @@
from aiocoap.transports import tinydtls

from ..error import ClientError, ServerError, RequestTimeout
from ..command import Command
from ..gateway import Gateway

_LOGGER = logging.getLogger(__name__)


class PatchedDTLSSecurityStore:
"""Patched DTLS store in lieu of impl."""

SECRET_PSK = None
IDENTITY = None
KEY = None

def _get_psk(self, host, port):
return b"Client_identity", PatchedDTLSSecurityStore.SECRET_PSK
return PatchedDTLSSecurityStore.IDENTITY, PatchedDTLSSecurityStore.KEY


tinydtls.DTLSSecurityStore = PatchedDTLSSecurityStore


@asyncio.coroutine
def api_factory(host, security_code, loop=None):
"""Generate a request method."""
if loop is None:
loop = asyncio.get_event_loop()
class APIFactory:
def __init__(self, host, psk_id='pytradfri', psk=None, loop=None):
self._psk = psk
self._host = host
self._psk_id = psk_id
self._loop = loop
self._observations_err_callbacks = []
self._protocol = None

PatchedDTLSSecurityStore.SECRET_PSK = security_code.encode('utf-8')
if self._loop is None:
self._loop = asyncio.get_event_loop()

_observations_err_callbacks = []
_protocol = yield from Context.create_client_context(loop=loop)
PatchedDTLSSecurityStore.IDENTITY = self._psk_id.encode('utf-8')

if self._psk:
PatchedDTLSSecurityStore.KEY = self._psk.encode('utf-8')

@property
def psk_id(self):
return self._psk_id

@psk_id.setter
def psk_id(self, value):
self._psk_id = value
PatchedDTLSSecurityStore.IDENTITY = self._psk_id.encode('utf-8')

@property
def psk(self):
return self._psk

@psk.setter
def psk(self, value):
self._psk = value
PatchedDTLSSecurityStore.KEY = self._psk.encode('utf-8')

@asyncio.coroutine
def _get_protocol():
def _get_protocol(self):
"""Get the protocol for the request."""
nonlocal _protocol
if not _protocol:
_protocol = yield from Context.create_client_context(loop=loop)
return _protocol
if not self._protocol:
self._protocol = yield from Context.create_client_context(
loop=self._loop)
return self._protocol

@asyncio.coroutine
def _reset_protocol(exc):
def _reset_protocol(self, exc=None):
"""Reset the protocol if an error occurs.
This can be removed when chrysn/aiocoap#79 is closed."""
# Be responsible and clean up.
protocol = yield from _get_protocol()
protocol = yield from self._get_protocol()
yield from protocol.shutdown()
nonlocal _protocol
_protocol = None
self._protocol = None
# Let any observers know the protocol has been shutdown.
nonlocal _observations_err_callbacks
for ob_error in _observations_err_callbacks:
for ob_error in self._observations_err_callbacks:
ob_error(exc)
_observations_err_callbacks.clear()
self._observations_err_callbacks.clear()

@asyncio.coroutine
def _get_response(msg):
def _get_response(self, msg):
"""Perform the request, get the response."""
try:
protocol = yield from _get_protocol()
protocol = yield from self._get_protocol()
pr = protocol.request(msg)
r = yield from pr.response
return pr, r
except ConstructionRenderableError as e:
raise ClientError("There was an error with the request.", e)
except RequestTimedOut as e:
yield from _reset_protocol(e)
yield from self._reset_protocol(e)
raise RequestTimeout('Request timed out.', e)
except Error as e:
yield from _reset_protocol(e)
yield from self._reset_protocol(e)
raise ServerError("There was an error with the request.", e)

@asyncio.coroutine
def _execute(api_command):
def _execute(self, api_command):
"""Execute the command."""
if api_command.observe:
yield from _observe(api_command)
yield from self._observe(api_command)
return

method = api_command.method
path = api_command.path
data = api_command.data
parse_json = api_command.parse_json
url = api_command.url(host)
url = api_command.url(self._host)

kwargs = {}

if data is not None:
kwargs['payload'] = json.dumps(data).encode('utf-8')
_LOGGER.debug('Executing %s %s %s: %s', host, method, path, data)
_LOGGER.debug('Executing %s %s %s: %s', self._host, method, path,
data)
else:
_LOGGER.debug('Executing %s %s %s', host, method, path)
_LOGGER.debug('Executing %s %s %s', self._host, method, path)

api_method = Code.GET
if method == 'put':
api_method = Code.PUT
elif method == 'post':
api_method = Code.POST
elif method == 'delete':
api_method = Code.DELETE
elif method == 'fetch':
api_method = Code.FETCH
elif method == 'patch':
api_method = Code.PATCH

msg = Message(code=api_method, uri=url, **kwargs)

_, res = yield from _get_response(msg)
_, res = yield from self._get_response(msg)

api_command.result = _process_output(res, parse_json)

return api_command.result

@asyncio.coroutine
def request(api_commands):
def request(self, api_commands):
"""Make a request."""
if not isinstance(api_commands, list):
result = yield from _execute(api_commands)
result = yield from self._execute(api_commands)
return result

commands = (_execute(api_command) for api_command in api_commands)
command_results = yield from asyncio.gather(*commands, loop=loop)
commands = (self._execute(api_command) for api_command in api_commands)
command_results = yield from asyncio.gather(*commands, loop=self._loop)

return command_results

@asyncio.coroutine
def _observe(api_command):
def _observe(self, api_command):
"""Observe an endpoint."""
duration = api_command.observe_duration
url = api_command.url(host)
url = api_command.url(self._host)
err_callback = api_command.err_callback

msg = Message(code=Code.GET, uri=url, observe=duration)

# Note that this is necessary to start observing
pr, r = yield from _get_response(msg)
pr, r = yield from self._get_response(msg)

api_command.result = _process_output(r)

Expand All @@ -145,13 +177,29 @@ def error_callback(ex):
ob = pr.observation
ob.register_callback(success_callback)
ob.register_errback(error_callback)
nonlocal _observations_err_callbacks
_observations_err_callbacks.append(ob.error)

# This will cause a RequestError to be raised if credentials invalid
yield from request(Command('get', ['status']))
self._observations_err_callbacks.append(ob.error)

return request
@asyncio.coroutine
def generate_psk(self, security_key):
"""
Generate and set a psk from the security key.
"""
if not self._psk:
PatchedDTLSSecurityStore.IDENTITY = 'Client_identity'.encode(
'utf-8')
PatchedDTLSSecurityStore.KEY = security_key.encode('utf-8')

command = Gateway().generate_psk(self._psk_id)
self._psk = yield from self.request(command)

PatchedDTLSSecurityStore.IDENTITY = self._psk_id.encode('utf-8')
PatchedDTLSSecurityStore.KEY = self._psk.encode('utf-8')

# aiocoap has now cached our psk, so it must be reset.
# We also no longer need the protocol, so this will shutdown that.
yield from self._reset_protocol()

return self._psk


def _process_output(res, parse_json=True):
Expand Down
Loading

0 comments on commit ed1fea1

Please sign in to comment.