Skip to content

Commit

Permalink
Setup OTA Provider App automatically when necessary
Browse files Browse the repository at this point in the history
Start and commission OTA Provider App when necessary. Use random
discriminator and passcode. Store the Node ID of the OTA Provider
App once setup for fast re-use.
  • Loading branch information
agners committed May 17, 2024
1 parent 4e75b77 commit 745ff7d
Show file tree
Hide file tree
Showing 5 changed files with 219 additions and 45 deletions.
7 changes: 7 additions & 0 deletions matter_server/server/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,12 @@
default=None,
help="Directory where PAA root certificates are stored.",
)
parser.add_argument(
"--ota-provider-dir",
type=str,
default=None,
help="Directory where OTA Provider stores software updates and configuration.",
)

args = parser.parse_args()

Expand Down Expand Up @@ -186,6 +192,7 @@ def main() -> None:
args.listen_address,
args.primary_interface,
args.paa_root_cert_dir,
args.ota_provider_dir,
)

async def handle_stop(loop: asyncio.AbstractEventLoop) -> None:
Expand Down
2 changes: 2 additions & 0 deletions matter_server/server/const.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,5 @@
.parent.resolve()
.joinpath("credentials/development/paa-root-certs")
)

DEFAULT_OTA_PROVIDER_DIR: Final[pathlib.Path] = pathlib.Path().cwd().joinpath("updates")
87 changes: 83 additions & 4 deletions matter_server/server/device_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,12 @@
import time
from typing import TYPE_CHECKING, Any, TypeVar, cast

from chip.clusters import Attribute, Objects as Clusters
from chip.clusters import Attribute, Objects as Clusters, Types
from chip.clusters.Attribute import ValueDecodeFailure
from chip.clusters.ClusterObjects import ALL_ATTRIBUTES, ALL_CLUSTERS, Cluster
from chip.discovery import DiscoveryType
from chip.exceptions import ChipStackError
from chip.interaction_model import Status
from zeroconf import BadTypeInNameException, IPVersion, ServiceStateChange, Zeroconf
from zeroconf.asyncio import AsyncServiceBrowser, AsyncServiceInfo, AsyncZeroconf

Expand Down Expand Up @@ -138,7 +139,7 @@ def __init__(
self._node_lock: dict[int, asyncio.Lock] = {}
self._ota_provider: ExternalOtaProvider | None = None

async def initialize(self, paa_root_cert_dir: Path) -> None:
async def initialize(self, paa_root_cert_dir: Path, ota_provider_dir: Path) -> None:
"""Async initialize of controller."""
# (re)fetch all PAA certificates once at startup
# NOTE: this must be done before initializing the controller
Expand All @@ -152,7 +153,8 @@ async def initialize(self, paa_root_cert_dir: Path) -> None:
int, await self._call_sdk(self.chip_controller.GetCompressedFabricId)
)
self.fabric_id_hex = hex(self.compressed_fabric_id)[2:]
self._ota_provider = ExternalOtaProvider()
self._ota_provider = ExternalOtaProvider(ota_provider_dir)
await self._ota_provider.initialize()
LOGGER.debug("CHIP Device Controller Initialized")

async def start(self) -> None:
Expand Down Expand Up @@ -959,17 +961,94 @@ async def update_node(self, node_id: int) -> dict | None:
# Add to OTA provider
await self._ota_provider.download_update(update)

ota_provider_node_id = self._ota_provider.get_node_id()
if ota_provider_node_id not in self._nodes:
LOGGER.warning(
"OTA Provider node id %d no longer exists! Resetting...",
ota_provider_node_id,
)
await self._ota_provider.reset()
ota_provider_node_id = None

# Make sure any previous instances get stopped
await self._ota_provider.stop()
self._ota_provider.start()

# Wait for OTA provider to be ready
# TODO: Detect when OTA provider is ready
await asyncio.sleep(2)

if not ota_provider_node_id:
# The OTA Provider has not been commissioned yet, let's do it now.
LOGGER.info("Commissioning the built-in OTA Provider App.")
try:
ota_provider_node = await self.commission_on_network(
self._ota_provider.get_passcode(),
# TODO: Filtering by long discriminator seems broken
# filter_type=FilterType.LONG_DISCRIMINATOR,
# filter=self._ota_provider.get_descriminator(),
)
ota_provider_node_id = ota_provider_node.node_id
except NodeCommissionFailed:
LOGGER.error("Failed to commission OTA Provider App!")
return None
LOGGER.info(
"OTA Provider App commissioned with node id %d.",
ota_provider_node_id,
)

# Adjust ACL of OTA Requestor such that Node peer-to-peer communication
# is allowed.
try:
read_result = await self.chip_controller.ReadAttribute(
ota_provider_node_id, [(0, Clusters.AccessControl.Attributes.Acl)]
)
acl_list = cast(
list,
read_result[0][Clusters.AccessControl][
Clusters.AccessControl.Attributes.Acl
],
)

# Add new ACL entry...
acl_list.append(
Clusters.AccessControl.Structs.AccessControlEntryStruct(
fabricIndex=1,
privilege=3,
authMode=2,
subjects=Types.NullValue,
targets=[
Clusters.AccessControl.Structs.AccessControlTargetStruct(
cluster=41, endpoint=0, deviceType=Types.NullValue
)
],
)
)

# And write. This is persistent, so only need to be done after we commissioned
# the OTA Provider App.
write_result: Attribute.AttributeWriteResult = (
await self.chip_controller.WriteAttribute(
ota_provider_node_id,
[(0, Clusters.AccessControl.Attributes.Acl(acl_list))],
)
)
if write_result[0].Status != Status.Success:
logging.error("Failed writing adjusted OTA Provider App ACL.")
await self.remove_node(ota_provider_node_id)
return None
except ChipStackError as ex:
logging.exception("Failed adjusting OTA Provider App ACL.", exc_info=ex)
await self.remove_node(ota_provider_node_id)
else:
self._ota_provider.set_node_id(ota_provider_node_id)

# Notify node about the new update!
await self.chip_controller.SendCommand(
nodeid=node_id,
endpoint=0,
payload=Clusters.OtaSoftwareUpdateRequestor.Commands.AnnounceOTAProvider(
providerNodeID=32,
providerNodeID=ota_provider_node_id,
vendorID=0, # TODO: Use Server Vendor ID
announcementReason=Clusters.OtaSoftwareUpdateRequestor.Enums.AnnouncementReasonEnum.kUpdateAvailable,
endpoint=0,
Expand Down
153 changes: 114 additions & 39 deletions matter_server/server/ota/provider.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import json
import logging
from pathlib import Path
import secrets
from typing import TYPE_CHECKING, Final
from urllib.parse import unquote, urlparse

Expand Down Expand Up @@ -37,9 +38,12 @@ class DeviceSoftwareVersionModel: # pylint: disable=C0103


@dataclass
class UpdateFile: # pylint: disable=C0103
class OtaProviderImageList: # pylint: disable=C0103
"""Update File for OTA Provider JSON descriptor file."""

otaProviderDiscriminator: int
otaProviderPasscode: int
otaProviderNodeId: int | None
deviceSoftwareVersionModel: list[DeviceSoftwareVersionModel]


Expand All @@ -50,23 +54,103 @@ class ExternalOtaProvider:
for devices.
"""

def __init__(self) -> None:
def __init__(self, ota_provider_dir: Path) -> None:
"""Initialize the OTA provider."""
self._ota_provider_dir: Path = ota_provider_dir
self._ota_provider_image_list_file: Path = ota_provider_dir / "updates.json"
self._ota_provider_image_list: OtaProviderImageList | None = None
self._ota_provider_proc: Process | None = None
self._ota_provider_task: asyncio.Task | None = None

async def initialize(self) -> None:
"""Initialize OTA Provider."""

loop = asyncio.get_event_loop()

# Take existence of image list file as indicator if we need to initialize the
# OTA Provider.
if not await loop.run_in_executor(
None, self._ota_provider_image_list_file.exists
):
await loop.run_in_executor(
None, functools.partial(DEFAULT_UPDATES_PATH.mkdir, exist_ok=True)
)

# Initialize with random data. Node ID will get written once paired by
# device controller.
self._ota_provider_image_list = OtaProviderImageList(
otaProviderDiscriminator=secrets.randbelow(2**12),
otaProviderPasscode=secrets.randbelow(2**21),
otaProviderNodeId=None,
deviceSoftwareVersionModel=[],
)
else:

def _read_update_json(
update_json_path: Path,
) -> None | OtaProviderImageList:
with open(update_json_path, "r") as json_file:
data = json.load(json_file)
return dataclass_from_dict(OtaProviderImageList, data)

self._ota_provider_image_list = await loop.run_in_executor(
None, _read_update_json, self._ota_provider_image_list_file
)

def _get_ota_provider_image_list(self) -> OtaProviderImageList:
if self._ota_provider_image_list is None:
raise RuntimeError("OTA provider image list not initialized.")
return self._ota_provider_image_list

def get_node_id(self) -> int | None:
"""Get Node ID of the OTA Provider App."""

return self._get_ota_provider_image_list().otaProviderNodeId

def get_descriminator(self) -> int:
"""Return OTA Provider App discriminator."""

return self._get_ota_provider_image_list().otaProviderDiscriminator

def get_passcode(self) -> int:
"""Return OTA Provider App passcode."""

return self._get_ota_provider_image_list().otaProviderPasscode

def set_node_id(self, node_id: int) -> None:
"""Set Node ID of the OTA Provider App."""

self._get_ota_provider_image_list().otaProviderNodeId = node_id

async def _start_ota_provider(self) -> None:
# TODO: Randomize discriminator
def _write_ota_provider_image_list_json(
ota_provider_image_list_file: Path,
ota_provider_image_list: OtaProviderImageList,
) -> None:
update_file_dict = asdict(ota_provider_image_list)
with open(ota_provider_image_list_file, "w") as json_file:
json.dump(update_file_dict, json_file, indent=4)

loop = asyncio.get_running_loop()
await loop.run_in_executor(
None,
_write_ota_provider_image_list_json,
self._ota_provider_image_list_file,
self._get_ota_provider_image_list(),
)

ota_provider_cmd = [
"chip-ota-provider-app",
"--discriminator",
"22",
str(self._get_ota_provider_image_list().otaProviderDiscriminator),
"--passcode",
str(self._get_ota_provider_image_list().otaProviderPasscode),
"--secured-device-port",
"5565",
"--KVS",
"/data/chip_kvs_provider",
str(self._ota_provider_dir / "chip_kvs_ota_provider"),
"--otaImageList",
str(DEFAULT_UPDATES_PATH / "updates.json"),
str(self._ota_provider_image_list_file),
]

LOGGER.info("Starting OTA Provider")
Expand All @@ -80,40 +164,41 @@ def start(self) -> None:
loop = asyncio.get_event_loop()
self._ota_provider_task = loop.create_task(self._start_ota_provider())

async def reset(self) -> None:
"""Reset the OTA Provider App state."""

def _remove_update_data(ota_provider_dir: Path) -> None:
for path in ota_provider_dir.iterdir():
if not path.is_dir():
path.unlink()

loop = asyncio.get_event_loop()
await loop.run_in_executor(None, _remove_update_data, self._ota_provider_dir)

await self.initialize()

async def stop(self) -> None:
"""Stop the OTA Provider."""
if self._ota_provider_proc:
LOGGER.info("Terminating OTA Provider")
self._ota_provider_proc.terminate()
loop = asyncio.get_event_loop()
try:
await loop.run_in_executor(None, self._ota_provider_proc.terminate)
except ProcessLookupError as ex:
LOGGER.warning("Stopping OTA Provider failed with error:", exc_info=ex)
if self._ota_provider_task:
await self._ota_provider_task

async def add_update(self, update_desc: dict, ota_file: Path) -> None:
"""Add update to the OTA provider."""

update_json_path = DEFAULT_UPDATES_PATH / "updates.json"

def _read_update_json(update_json_path: Path) -> None | UpdateFile:
if not update_json_path.exists():
return None

with open(update_json_path, "r") as json_file:
data = json.load(json_file)
return dataclass_from_dict(UpdateFile, data)

loop = asyncio.get_running_loop()
update_file = await loop.run_in_executor(
None, _read_update_json, update_json_path
)

if not update_file:
update_file = UpdateFile(deviceSoftwareVersionModel=[])

local_ota_url = str(ota_file)
for i, device_software in enumerate(update_file.deviceSoftwareVersionModel):
for i, device_software in enumerate(
self._get_ota_provider_image_list().deviceSoftwareVersionModel
):
if device_software.otaURL == local_ota_url:
LOGGER.debug("Device software entry exists already, replacing!")
del update_file.deviceSoftwareVersionModel[i]
del self._get_ota_provider_image_list().deviceSoftwareVersionModel[i]

# Convert to OTA Requestor descriptor file
new_device_software = DeviceSoftwareVersionModel(
Expand All @@ -127,18 +212,8 @@ def _read_update_json(update_json_path: Path) -> None | UpdateFile:
maxApplicableSoftwareVersion=update_desc["maxApplicableSoftwareVersion"],
otaURL=local_ota_url,
)
update_file.deviceSoftwareVersionModel.append(new_device_software)

def _write_update_json(update_json_path: Path, update_file: UpdateFile) -> None:
update_file_dict = asdict(update_file)
with open(update_json_path, "w") as json_file:
json.dump(update_file_dict, json_file, indent=4)

await loop.run_in_executor(
None,
_write_update_json,
update_json_path,
update_file,
self._get_ota_provider_image_list().deviceSoftwareVersionModel.append(
new_device_software
)

async def download_update(self, update_desc: dict) -> None:
Expand Down
Loading

0 comments on commit 745ff7d

Please sign in to comment.