7
7
import struct
8
8
import sys
9
9
import time
10
+ from pathlib import Path
10
11
11
12
# from Crypto.Cipher import AES
12
13
from rsa_sign import RsaSign
23
24
SIGNED_FILE_EXTENSION = '.ota'
24
25
25
26
OTA_PORT = 8266
26
- SOCKET_TIMEOUT = 10
27
+ SOCKET_TIMEOUT = 0.3
27
28
28
- DNS_SERVER = '8.8.8.8' # Google DNS Server ot get own IP
29
- ENCODING = 'utf-8 '
29
+ # The first OTA package will be send this this broadcast address:
30
+ BROADCAST_ADDRESS = '255.255.255.255 '
30
31
31
32
32
33
def signed_filename (fname ):
@@ -68,13 +69,23 @@ def validate_ota(fname):
68
69
class OtaClient :
69
70
def __init__ (self , fname = None ):
70
71
self .fname = fname
72
+ self .total_size = None
71
73
self .rsa_sign = RsaSign ()
72
74
73
75
self .rsa_key = None
74
76
self .last_aes_key = None
75
77
self .last_seq = 0
76
78
self .rexmit = 0
77
79
80
+ self .sock = socket .socket (socket .AF_INET , socket .SOCK_DGRAM )
81
+ self .sock .setsockopt (socket .SOL_SOCKET , socket .SO_REUSEADDR , 1 )
82
+ self .sock .setsockopt (socket .SOL_SOCKET , socket .SO_BROADCAST , 1 )
83
+ self .sock .settimeout (SOCKET_TIMEOUT )
84
+
85
+ self .device_ip = BROADCAST_ADDRESS # send first packet as broadcast
86
+ self .next_update = 0
87
+ self .start_time = 0
88
+
78
89
def add_digest (self , pkt ):
79
90
aes_key = AES_KEY
80
91
# last_aes_key = aes_key
@@ -98,18 +109,39 @@ def decode_pkt(self, pkt):
98
109
# return aes.decrypt(pkt)
99
110
return pkt
100
111
101
- async def send_recv (self , reader , writer , offset , pkt , data_len ):
102
- while True :
103
- try :
104
- print ('Sending # %d' % self .last_seq )
112
+ def send_recv (self , offset , pkt , data_len ):
113
+
114
+ if self .last_seq == 1 and self .device_ip == BROADCAST_ADDRESS :
115
+ print ('wait for response...' , end = '' )
116
+ elif time .time () > self .next_update :
117
+ duration = time .time () - self .start_time
118
+ sended = offset + data_len
119
+ throughput = sended / duration / 1024
120
+
121
+ percent = 100 / self .total_size * sended
105
122
106
- print ('send:' , pkt )
107
- writer .write (pkt )
108
- await writer .drain ()
123
+ print (f'{ percent :.1f} % Sending #{ self .last_seq } ({ throughput :.1f} KBytes/s)' )
124
+ self .next_update = time .time () + 1
109
125
110
- print ('wait for response...' , end = '' )
111
- resp = await reader .read (1024 )
112
- print ('resp:' , resp , len (resp ))
126
+ while True :
127
+ try :
128
+ self .sock .sendto (pkt , (self .device_ip , OTA_PORT ))
129
+ try :
130
+ resp , server = self .sock .recvfrom (1024 )
131
+ except socket .timeout :
132
+ if self .last_seq == 1 :
133
+ # no device has responded, yet
134
+ if time .time () > self .next_update :
135
+ print ('.' , end = '' , flush = True )
136
+ self .next_update = time .time () + 1
137
+ continue
138
+ else :
139
+ raise
140
+
141
+ if self .start_time == 0 :
142
+ self .start_time = time .time ()
143
+
144
+ # print('resp:', resp, len(resp))
113
145
114
146
resp_seq = struct .unpack ('<I' , resp [:4 ])[0 ]
115
147
if resp_seq != self .last_seq :
@@ -118,45 +150,63 @@ async def send_recv(self, reader, writer, offset, pkt, data_len):
118
150
119
151
resp = resp [4 :]
120
152
resp = self .decode_pkt (resp )
121
- print ('decoded resp:' , resp )
153
+ # print('decoded resp:', resp)
122
154
123
155
resp_op , resp_len , resp_off = struct .unpack ('<HHI' , resp [:8 ])
124
- print ('resp:' , (resp_seq , resp_op , resp_len , resp_off ))
156
+ # print('resp:', (resp_seq, resp_op, resp_len, resp_off))
125
157
126
158
if resp_off != offset or resp_len != data_len :
127
159
print ('Invalid resp' )
128
160
continue
161
+
162
+ if self .device_ip == BROADCAST_ADDRESS :
163
+ # set device IP address and send all next packages to this address
164
+ print ('received from:' , repr (server ))
165
+ self .device_ip = server [0 ]
166
+
129
167
break
130
168
except socket .timeout :
131
- print ('timeout' )
169
+ if time .time () > self .next_update :
170
+ print ('t' , end = '' , flush = True )
171
+ self .next_update = time .time () + 1
172
+
132
173
# For such packets we don't expect reply
133
174
if offset == 0 and data_len == 0 :
134
175
break
135
176
136
177
self .rexmit += 1
137
178
138
- async def send_ota_end (self , writer ):
179
+ def send_ota_end (self ):
139
180
# Repeat few times to minimize chance of being lost
181
+ print ('Send OTA end' , end = '' , flush = True )
140
182
for i in range (3 ):
141
183
pkt = self .make_pkt (0 , b'' )
142
- writer .write (pkt )
143
- await writer .drain ()
184
+ self .sock .sendto (pkt , (self .device_ip , OTA_PORT ))
144
185
time .sleep (0.1 )
186
+ print ('.' , end = '' , flush = True )
187
+
188
+ def live_ota (self ):
189
+ file_path = Path (self .fname )
190
+ self .total_size = file_path .stat ().st_size
145
191
146
- async def live_ota (self , reader , writer ):
147
192
offset = 0
148
- with open (self . fname , 'rb' ) as f :
193
+ with file_path . open ('rb' ) as f :
149
194
while True :
150
195
chunk = f .read (BLK_SIZE )
151
196
if not chunk :
152
197
break
153
198
pkt = self .make_pkt (offset , chunk )
154
- await self .send_recv (reader , writer , offset , pkt , len (chunk ))
199
+ self .send_recv (offset , pkt , len (chunk ))
155
200
offset += len (chunk )
156
201
157
- await self .send_ota_end (writer )
202
+ self .send_ota_end ()
158
203
print ('Done, rexmits: %d' % self .rexmit )
159
204
205
+ duration = time .time () - self .start_time
206
+ throughput = self .total_size / duration / 1024
207
+
208
+ print (f'Send { self .total_size } Bytes in { duration :.1f} sec ({ throughput :.1f} KBytes/s)' )
209
+
160
210
def sign (self , fname ):
161
211
print (f'Sign firmware file { fname } ...' )
162
212
@@ -189,8 +239,11 @@ def hash_write(data):
189
239
190
240
print (f'Signed file created: { out_filename } ' )
191
241
192
- async def canned_ota (self , reader , writer ):
193
- with open (self .fname , 'rb' ) as f_in :
242
+ def canned_ota (self ):
243
+ file_path = Path (self .fname )
244
+ self .total_size = file_path .stat ().st_size
245
+
246
+ with file_path .open ('rb' ) as f_in :
194
247
# Skip signature
195
248
f_in .read (10 )
196
249
while True :
@@ -200,115 +253,14 @@ async def canned_ota(self, reader, writer):
200
253
break
201
254
data = f_in .read (sz )
202
255
last_seq , op , data_len , offset = struct .unpack ('<IHHI' , data [:12 ])
203
- await self .send_recv (reader , writer , offset , data , data_len )
256
+ self .send_recv (offset , data , data_len )
204
257
205
258
print ('Done, rexmits: %d' % self .rexmit )
206
259
260
+ duration = time .time () - self .start_time
261
+ throughput = self .total_size / duration / 1024
207
262
208
- def get_ip_address ():
209
- """
210
- :return: IP address of the host running this script.
211
- """
212
- s = socket .socket (socket .AF_INET , socket .SOCK_DGRAM )
213
- s .settimeout (SOCKET_TIMEOUT )
214
- s .connect ((DNS_SERVER , 80 ))
215
- ip = s .getsockname ()[0 ]
216
- s .close ()
217
- return ip
218
-
219
-
220
- class CommunicationError (RuntimeError ):
221
- pass
222
-
223
-
224
- def ip_range_iterator (own_ip , exclude_own ):
225
- ip_prefix , own_no = own_ip .rsplit ('.' , 1 )
226
- print (f'Scan:.....: { ip_prefix } .X' )
227
-
228
- own_no = int (own_no )
229
-
230
- for no in range (1 , 255 ):
231
- if exclude_own and no == own_no :
232
- continue
233
-
234
- yield f'{ ip_prefix } .{ no } '
235
-
236
-
237
- class OtaStreamWriter (asyncio .StreamWriter ):
238
- encoding = 'utf-8'
239
-
240
- async def write_text_line (self , text ):
241
- self .write (b'%s\n ' % text .encode ('utf-8' ))
242
- await self .drain ()
243
-
244
- async def sendall (self , data ):
245
- self .write (data )
246
- await self .drain ()
247
-
248
-
249
- async def open_connection (host = None , port = None ):
250
- """A wrapper for create_connection() returning a (reader, writer) pair.
251
-
252
- Similar as asyncio.open_connection() but we use own OtaStreamWriter()
253
- """
254
- loop = asyncio .get_event_loop ()
255
- reader = asyncio .StreamReader (limit = 2 ** 16 , loop = loop )
256
- protocol = asyncio .StreamReaderProtocol (reader , loop = loop )
257
- transport , _ = await loop .create_connection (lambda : protocol , host , port )
258
- writer = OtaStreamWriter (transport , protocol , reader , loop )
259
- return reader , writer
260
-
261
-
262
- class AsyncConnector :
263
- """
264
- Scan the own IP range and start callback if receiver found.
265
- """
266
- def __init__ (self , callback ):
267
- self .callback = callback
268
-
269
- async def port_scan_and_serve (self , port ):
270
- own_ip = get_ip_address ()
271
- print (f'Own IP....: { own_ip } ' )
272
- ips = tuple (ip_range_iterator (own_ip , exclude_own = True ))
273
-
274
- print (f'Wait for receivers on port: { port } ' , end = ' ' , flush = True )
275
- clients = []
276
- while True :
277
- connections = [
278
- asyncio .wait_for (open_connection (ip , port ), timeout = 0.5 )
279
- for ip in ips
280
- ]
281
- results = await asyncio .gather (* connections , return_exceptions = True )
282
- for ip , result in zip (ips , results ):
283
- if isinstance (result , asyncio .TimeoutError ):
284
- continue
285
- elif not isinstance (result , tuple ):
286
- # print(result)
287
- continue
288
-
289
- reader , writer = result
290
-
291
- print ('Connected to:' , ip )
292
- peername = writer .get_extra_info ('peername' )
293
- print (f'Connect to { peername [0 ]} :{ peername [1 ]} ' )
294
- try :
295
- await self .callback (reader , writer )
296
- except ConnectionResetError as e :
297
- print (e )
298
- continue
299
- clients .append (ip )
300
-
301
- if clients :
302
- return clients
303
-
304
- print ('.' , end = '' , flush = True )
305
- time .sleep (2 )
306
-
307
- def scan (self , port ):
308
- loop = asyncio .get_event_loop ()
309
- return loop .run_until_complete (
310
- self .port_scan_and_serve (port = port )
311
- )
263
+ print (f'Send { self .total_size } Bytes in { duration :.1f} sec ({ throughput :.1f} KBytes/s)' )
312
264
313
265
314
266
def cli ():
@@ -326,16 +278,12 @@ def cli():
326
278
# Do the OTA update for a device
327
279
validate_ota (args .file )
328
280
329
- AsyncConnector (
330
- callback = OtaClient (args .file ).live_ota
331
- ).scan (port = OTA_PORT )
281
+ OtaClient (args .file ).live_ota ()
332
282
333
283
elif args .command == 'ota' :
334
284
validate_ota (args .file )
335
285
336
- AsyncConnector (
337
- callback = OtaClient (args .file ).canned_ota
338
- ).scan (port = OTA_PORT )
286
+ OtaClient (args .file ).canned_ota ()
339
287
340
288
else :
341
289
cmd_parser .error ('Unknown command' )
0 commit comments