Skip to content

Commit

Permalink
Add API for running test flows (#91)
Browse files Browse the repository at this point in the history
* Add API for running test flows
  • Loading branch information
adelefelicia authored and Haakon Karstensen committed Nov 5, 2024
1 parent a18ab20 commit 7c43156
Show file tree
Hide file tree
Showing 9 changed files with 255 additions and 0 deletions.
110 changes: 110 additions & 0 deletions server/server_comm/flow_parser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from collections import defaultdict, deque
from typing import Any


class FlowParser:

def __init__(self, flow: Any):
self.flow = flow
self.nodes = list(flow.nodes.all())
self.edges = list(flow.edges.all())

self.outgoing_edges = defaultdict(list)
self.incoming_edges_counter = {node.id: 0 for node in self.nodes}

for edge in self.edges:
self.outgoing_edges[edge.source.id].append(edge.target.id)
self.incoming_edges_counter[edge.target.id] += 1

def get_start_node(self) -> int:
"""
Find potential starting nodes (those with no incoming edges)
and return the one with the largest reachable subgraph.
"""
start_candidates = [
node.id
for node in self.nodes
if self.incoming_edges_counter[node.id] == 0
]
max_connected_nodes = 0
best_start_node = None

for node_id in start_candidates:
connected_nodes = self.bfs_connected_count(node_id)
if connected_nodes > max_connected_nodes:
max_connected_nodes = connected_nodes
best_start_node = node_id
elif connected_nodes == max_connected_nodes:
# Tiebreaker is smallest ID
best_start_node = min(best_start_node, node_id)

if best_start_node is None:
raise ValueError(
"No start node found. Check for cycles and missing edges."
)

return best_start_node

def bfs_connected_count(self, start_node_id: int) -> int:
"""
Count the number of nodes reachable from start_node_id using BFS.
"""
visited = set()
queue = deque([start_node_id])

while queue:
node_id = queue.popleft()
if node_id not in visited:
visited.add(node_id)
for neighbor in self.outgoing_edges[node_id]:
if neighbor not in visited:
queue.append(neighbor)

return len(visited)

def get_execution_order(self) -> list[list[Any]]:
"""
Perform a topological sort with Kahn's algorith,
grouping nodes by levels to show parallelism, with cycle detection.
:return: List[List[Node]] - A list of lists,
where each sublist contains nodes that can be executed in parallel.
"""
start_node_id = self.get_start_node()
if not start_node_id:
return []

queue = deque([start_node_id])
visited = set()
level_order = []
nodes_visited_count = 0

while queue:
level_size = len(queue)
current_level = []

for _ in range(level_size):
node_id = queue.popleft()
if node_id in visited:
continue

visited.add(node_id)
current_level.append(node_id)
nodes_visited_count += 1

for neighbor in self.outgoing_edges[node_id]:
self.incoming_edges_counter[neighbor] -= 1
if (
self.incoming_edges_counter[neighbor] == 0
and neighbor not in visited
):
queue.append(neighbor)

if current_level:
level_order.append(current_level)

ordered_nodes = [
[self.flow.nodes.get(id=node_id) for node_id in level]
for level in level_order
]
return ordered_nodes
1 change: 1 addition & 0 deletions server/server_comm/server_comm/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@
"data_manager",
"device_connection",
"corsheaders",
"test_runner",
]

MIDDLEWARE = [
Expand Down
1 change: 1 addition & 0 deletions server/server_comm/server_comm/urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,5 @@
path("admin/", admin.site.urls),
path("data_manager/", include("data_manager.urls")),
path("device_connection/", include("device_connection.urls")),
path("test_runner/", include("test_runner.urls")),
]
Empty file.
6 changes: 6 additions & 0 deletions server/server_comm/test_runner/apps.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from django.apps import AppConfig


class TestRunnerConfig(AppConfig):
default_auto_field = 'django.db.models.BigAutoField'
name = 'test_runner'
Empty file.
Empty file.
8 changes: 8 additions & 0 deletions server/server_comm/test_runner/urls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
from django.urls import path

from .views import RunTestFlow

urlpatterns = [
path('run/<int:flow_id>/', RunTestFlow.as_view(), name='run-test-flow'),
path('run/', RunTestFlow.as_view(), name='run-all-test-flows'),
]
129 changes: 129 additions & 0 deletions server/server_comm/test_runner/views.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from urllib.parse import urlencode

import requests
from data_manager.models import Flow, Node
from django.shortcuts import get_object_or_404
from django.urls import reverse
from flow_parser import FlowParser
from rest_framework import status
from rest_framework.response import Response
from rest_framework.views import APIView


class RunTestFlow(APIView):
"""
API endpoint to run a specific test flow by flow_id
or all test flows if no flow_id is specified.
"""

def post(self, request, flow_id=None):
if flow_id:
flow = get_object_or_404(Flow, id=flow_id)
test_flows = [flow]
else:
# Run all test flows if no flow_id is specified
test_flows = Flow.objects.all()

results = []

self.test_setup(flow_id)

for flow in test_flows:
flow_result = {
"flow_name": flow.name,
"nodes_executed": [],
"status": "success",
"error": None,
}
try:
flow_parser = FlowParser(flow)
execution_order = flow_parser.get_execution_order()

for parallel_nodes in execution_order:
for node in parallel_nodes:

result = self.run_node(node)
flow_result["nodes_executed"].append(result)

if result["status"] == "failed":
flow_result["status"] = "failed"
flow_result["error"] = result["error"]
break

except Exception as e:
flow_result["status"] = "failed"
flow_result["error"] = str(e)

results.append(flow_result)

return Response({"results": results}, status=status.HTTP_200_OK)

def check_device_connections(self, flow_id):
"""
Checks device connectivity for the flow.
"""
protocol = 'https' if self.request.is_secure() else 'http'
base_url = f"{protocol}://{self.request.get_host()}"
check_devices_url = f"""
{base_url}{reverse('flow-device-connection')}
?{urlencode({'flow_id': flow_id})}
"""
response = requests.get(check_devices_url)

if response.status_code == 200:
connection_data = response.json()
if all(
status['status'] == 'connected'
for status in connection_data['response'].values()
):
return {
"status": "success",
"message": "All devices connected",
}
else:
return {
"status": "failed",
"message": "Some devices are not connected",
"details": connection_data['response'],
}
return {
"status": "error",
"message": "Failed to check device connections",
}

def test_setup(self, flow_id):
"""
Set up and assert device connections for the test flow.
"""
# TODO comment back in when this API is fixed
# self.check_device_connections(flow_id)

# TODO Connect android device(s) in command nodes to nrf kit (LIL-90)

# TODO Assert that connection is setup (LIL-90)

def run_node(self, node):
"""
Execute a single node's function and return the result.
"""
try:
result = None
if node.node_type == Node.ASSERT:
# TODO add code from LIL-91
result = True
elif node.node_type == Node.ACTION:
# TODO add code from LIL-95
result = True
else:
raise ValueError(f"Invalid node type: {node.node_type}")
return {
"node_label": node.label,
"status": "success",
"output": result,
}
except Exception as e:
return {
"node_label": node.label,
"status": "failed",
"error": str(e),
}

0 comments on commit 7c43156

Please sign in to comment.