Skip to content

Commit 459009c

Browse files
committed
Fix #1 by replace asyncio TCP AsyncConnector with UDP broadcast solution
1 parent 06cfb3c commit 459009c

File tree

1 file changed

+84
-136
lines changed

1 file changed

+84
-136
lines changed

ota-client/ota_client.py

Lines changed: 84 additions & 136 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import struct
88
import sys
99
import time
10+
from pathlib import Path
1011

1112
# from Crypto.Cipher import AES
1213
from rsa_sign import RsaSign
@@ -23,10 +24,10 @@
2324
SIGNED_FILE_EXTENSION = '.ota'
2425

2526
OTA_PORT = 8266
26-
SOCKET_TIMEOUT = 10
27+
SOCKET_TIMEOUT = 0.3
2728

28-
DNS_SERVER = '8.8.8.8' # Google DNS Server ot get own IP
29-
ENCODING = 'utf-8'
29+
# The first OTA package will be send this this broadcast address:
30+
BROADCAST_ADDRESS = '255.255.255.255'
3031

3132

3233
def signed_filename(fname):
@@ -68,13 +69,23 @@ def validate_ota(fname):
6869
class OtaClient:
6970
def __init__(self, fname=None):
7071
self.fname = fname
72+
self.total_size = None
7173
self.rsa_sign = RsaSign()
7274

7375
self.rsa_key = None
7476
self.last_aes_key = None
7577
self.last_seq = 0
7678
self.rexmit = 0
7779

80+
self.sock = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
81+
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
82+
self.sock.setsockopt(socket.SOL_SOCKET, socket.SO_BROADCAST, 1)
83+
self.sock.settimeout(SOCKET_TIMEOUT)
84+
85+
self.device_ip = BROADCAST_ADDRESS # send first packet as broadcast
86+
self.next_update = 0
87+
self.start_time = 0
88+
7889
def add_digest(self, pkt):
7990
aes_key = AES_KEY
8091
# last_aes_key = aes_key
@@ -98,18 +109,39 @@ def decode_pkt(self, pkt):
98109
# return aes.decrypt(pkt)
99110
return pkt
100111

101-
async def send_recv(self, reader, writer, offset, pkt, data_len):
102-
while True:
103-
try:
104-
print('Sending # %d' % self.last_seq)
112+
def send_recv(self, offset, pkt, data_len):
113+
114+
if self.last_seq == 1 and self.device_ip == BROADCAST_ADDRESS:
115+
print('wait for response...', end='')
116+
elif time.time() > self.next_update:
117+
duration = time.time() - self.start_time
118+
sended = offset + data_len
119+
throughput = sended / duration / 1024
120+
121+
percent = 100 / self.total_size * sended
105122

106-
print('send:', pkt)
107-
writer.write(pkt)
108-
await writer.drain()
123+
print(f'{percent:.1f}% Sending #{self.last_seq} ({throughput:.1f} KBytes/s)')
124+
self.next_update = time.time() + 1
109125

110-
print('wait for response...', end='')
111-
resp = await reader.read(1024)
112-
print('resp:', resp, len(resp))
126+
while True:
127+
try:
128+
self.sock.sendto(pkt, (self.device_ip, OTA_PORT))
129+
try:
130+
resp, server = self.sock.recvfrom(1024)
131+
except socket.timeout:
132+
if self.last_seq == 1:
133+
# no device has responded, yet
134+
if time.time() > self.next_update:
135+
print('.', end='', flush=True)
136+
self.next_update = time.time() + 1
137+
continue
138+
else:
139+
raise
140+
141+
if self.start_time == 0:
142+
self.start_time = time.time()
143+
144+
# print('resp:', resp, len(resp))
113145

114146
resp_seq = struct.unpack('<I', resp[:4])[0]
115147
if resp_seq != self.last_seq:
@@ -118,45 +150,63 @@ async def send_recv(self, reader, writer, offset, pkt, data_len):
118150

119151
resp = resp[4:]
120152
resp = self.decode_pkt(resp)
121-
print('decoded resp:', resp)
153+
# print('decoded resp:', resp)
122154

123155
resp_op, resp_len, resp_off = struct.unpack('<HHI', resp[:8])
124-
print('resp:', (resp_seq, resp_op, resp_len, resp_off))
156+
# print('resp:', (resp_seq, resp_op, resp_len, resp_off))
125157

126158
if resp_off != offset or resp_len != data_len:
127159
print('Invalid resp')
128160
continue
161+
162+
if self.device_ip == BROADCAST_ADDRESS:
163+
# set device IP address and send all next packages to this address
164+
print('received from:', repr(server))
165+
self.device_ip = server[0]
166+
129167
break
130168
except socket.timeout:
131-
print('timeout')
169+
if time.time() > self.next_update:
170+
print('t', end='', flush=True)
171+
self.next_update = time.time() + 1
172+
132173
# For such packets we don't expect reply
133174
if offset == 0 and data_len == 0:
134175
break
135176

136177
self.rexmit += 1
137178

138-
async def send_ota_end(self, writer):
179+
def send_ota_end(self):
139180
# Repeat few times to minimize chance of being lost
181+
print('Send OTA end', end='', flush=True)
140182
for i in range(3):
141183
pkt = self.make_pkt(0, b'')
142-
writer.write(pkt)
143-
await writer.drain()
184+
self.sock.sendto(pkt, (self.device_ip, OTA_PORT))
144185
time.sleep(0.1)
186+
print('.', end='', flush=True)
187+
188+
def live_ota(self):
189+
file_path = Path(self.fname)
190+
self.total_size = file_path.stat().st_size
145191

146-
async def live_ota(self, reader, writer):
147192
offset = 0
148-
with open(self.fname, 'rb') as f:
193+
with file_path.open('rb') as f:
149194
while True:
150195
chunk = f.read(BLK_SIZE)
151196
if not chunk:
152197
break
153198
pkt = self.make_pkt(offset, chunk)
154-
await self.send_recv(reader, writer, offset, pkt, len(chunk))
199+
self.send_recv(offset, pkt, len(chunk))
155200
offset += len(chunk)
156201

157-
await self.send_ota_end(writer)
202+
self.send_ota_end()
158203
print('Done, rexmits: %d' % self.rexmit)
159204

205+
duration = time.time() - self.start_time
206+
throughput = self.total_size / duration / 1024
207+
208+
print(f'Send {self.total_size} Bytes in {duration:.1f}sec ({throughput:.1f} KBytes/s)')
209+
160210
def sign(self, fname):
161211
print(f'Sign firmware file {fname}...')
162212

@@ -189,8 +239,11 @@ def hash_write(data):
189239

190240
print(f'Signed file created: {out_filename}')
191241

192-
async def canned_ota(self, reader, writer):
193-
with open(self.fname, 'rb') as f_in:
242+
def canned_ota(self):
243+
file_path = Path(self.fname)
244+
self.total_size = file_path.stat().st_size
245+
246+
with file_path.open('rb') as f_in:
194247
# Skip signature
195248
f_in.read(10)
196249
while True:
@@ -200,115 +253,14 @@ async def canned_ota(self, reader, writer):
200253
break
201254
data = f_in.read(sz)
202255
last_seq, op, data_len, offset = struct.unpack('<IHHI', data[:12])
203-
await self.send_recv(reader, writer, offset, data, data_len)
256+
self.send_recv(offset, data, data_len)
204257

205258
print('Done, rexmits: %d' % self.rexmit)
206259

260+
duration = time.time() - self.start_time
261+
throughput = self.total_size / duration / 1024
207262

208-
def get_ip_address():
209-
"""
210-
:return: IP address of the host running this script.
211-
"""
212-
s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
213-
s.settimeout(SOCKET_TIMEOUT)
214-
s.connect((DNS_SERVER, 80))
215-
ip = s.getsockname()[0]
216-
s.close()
217-
return ip
218-
219-
220-
class CommunicationError(RuntimeError):
221-
pass
222-
223-
224-
def ip_range_iterator(own_ip, exclude_own):
225-
ip_prefix, own_no = own_ip.rsplit('.', 1)
226-
print(f'Scan:.....: {ip_prefix}.X')
227-
228-
own_no = int(own_no)
229-
230-
for no in range(1, 255):
231-
if exclude_own and no == own_no:
232-
continue
233-
234-
yield f'{ip_prefix}.{no}'
235-
236-
237-
class OtaStreamWriter(asyncio.StreamWriter):
238-
encoding = 'utf-8'
239-
240-
async def write_text_line(self, text):
241-
self.write(b'%s\n' % text.encode('utf-8'))
242-
await self.drain()
243-
244-
async def sendall(self, data):
245-
self.write(data)
246-
await self.drain()
247-
248-
249-
async def open_connection(host=None, port=None):
250-
"""A wrapper for create_connection() returning a (reader, writer) pair.
251-
252-
Similar as asyncio.open_connection() but we use own OtaStreamWriter()
253-
"""
254-
loop = asyncio.get_event_loop()
255-
reader = asyncio.StreamReader(limit=2 ** 16, loop=loop)
256-
protocol = asyncio.StreamReaderProtocol(reader, loop=loop)
257-
transport, _ = await loop.create_connection(lambda: protocol, host, port)
258-
writer = OtaStreamWriter(transport, protocol, reader, loop)
259-
return reader, writer
260-
261-
262-
class AsyncConnector:
263-
"""
264-
Scan the own IP range and start callback if receiver found.
265-
"""
266-
def __init__(self, callback):
267-
self.callback = callback
268-
269-
async def port_scan_and_serve(self, port):
270-
own_ip = get_ip_address()
271-
print(f'Own IP....: {own_ip}')
272-
ips = tuple(ip_range_iterator(own_ip, exclude_own=True))
273-
274-
print(f'Wait for receivers on port: {port}', end=' ', flush=True)
275-
clients = []
276-
while True:
277-
connections = [
278-
asyncio.wait_for(open_connection(ip, port), timeout=0.5)
279-
for ip in ips
280-
]
281-
results = await asyncio.gather(*connections, return_exceptions=True)
282-
for ip, result in zip(ips, results):
283-
if isinstance(result, asyncio.TimeoutError):
284-
continue
285-
elif not isinstance(result, tuple):
286-
# print(result)
287-
continue
288-
289-
reader, writer = result
290-
291-
print('Connected to:', ip)
292-
peername = writer.get_extra_info('peername')
293-
print(f'Connect to {peername[0]}:{peername[1]}')
294-
try:
295-
await self.callback(reader, writer)
296-
except ConnectionResetError as e:
297-
print(e)
298-
continue
299-
clients.append(ip)
300-
301-
if clients:
302-
return clients
303-
304-
print('.', end='', flush=True)
305-
time.sleep(2)
306-
307-
def scan(self, port):
308-
loop = asyncio.get_event_loop()
309-
return loop.run_until_complete(
310-
self.port_scan_and_serve(port=port)
311-
)
263+
print(f'Send {self.total_size} Bytes in {duration:.1f}sec ({throughput:.1f} KBytes/s)')
312264

313265

314266
def cli():
@@ -326,16 +278,12 @@ def cli():
326278
# Do the OTA update for a device
327279
validate_ota(args.file)
328280

329-
AsyncConnector(
330-
callback=OtaClient(args.file).live_ota
331-
).scan(port=OTA_PORT)
281+
OtaClient(args.file).live_ota()
332282

333283
elif args.command == 'ota':
334284
validate_ota(args.file)
335285

336-
AsyncConnector(
337-
callback=OtaClient(args.file).canned_ota
338-
).scan(port=OTA_PORT)
286+
OtaClient(args.file).canned_ota()
339287

340288
else:
341289
cmd_parser.error('Unknown command')

0 commit comments

Comments
 (0)