4
4
5
5
6
6
import socket
7
+ import threading
7
8
8
9
import utils
9
10
@@ -14,6 +15,7 @@ class SSLConnection(object):
14
15
BIO_CLOSE = 1
15
16
16
17
def __init__ (self , context , sock , ip_str = None , sni = None , on_close = None ):
18
+ self ._lock = threading .Lock ()
17
19
self ._context = context
18
20
self ._sock = sock
19
21
self .ip_str = utils .to_bytes (ip_str )
@@ -66,6 +68,9 @@ def wrap(self):
66
68
raise socket .error ("SSL_connect fail: %s" % error )
67
69
68
70
def do_handshake (self ):
71
+ if not self ._connection :
72
+ raise socket .error ("do_handshake fail: not connected" )
73
+
69
74
ret = bssl .SSL_do_handshake (self ._connection )
70
75
if ret == 1 :
71
76
return
@@ -74,6 +79,9 @@ def do_handshake(self):
74
79
raise socket .error ("do_handshake fail: %s" % error )
75
80
76
81
def is_support_h2 (self ):
82
+ if not self ._connection :
83
+ return False
84
+
77
85
out_data_pp = ffi .new ("uint8_t**" , ffi .NULL )
78
86
out_len_p = ffi .new ("unsigned*" )
79
87
bssl .SSL_get0_alpn_selected (self ._connection , out_data_pp , out_len_p )
@@ -90,21 +98,13 @@ def setblocking(self, block):
90
98
self ._sock .setblocking (block )
91
99
92
100
def __getattr__ (self , attr ):
93
- if attr == "socket_closed" :
94
- # work around in case close before finished init.
95
- return True
96
-
97
- elif attr in ('is_support_h2' , "_on_close" , '_context' , '_sock' , '_connection' , '_makefile_refs' ,
101
+ if attr in ('is_support_h2' , "_on_close" , '_context' , '_sock' , '_connection' , '_makefile_refs' ,
98
102
'sni' , 'wrap' , 'socket_closed' ):
99
103
return getattr (self , attr )
100
104
101
105
elif hasattr (self ._connection , attr ):
102
106
return getattr (self ._connection , attr )
103
107
104
- def __del__ (self ):
105
- if not self .socket_closed and self ._connection :
106
- self .close ()
107
-
108
108
def get_cert (self ):
109
109
if self .peer_cert :
110
110
return self .peer_cert
@@ -113,25 +113,27 @@ def x509_name_to_string(xname):
113
113
line = bssl .X509_NAME_oneline (xname , ffi .NULL , 0 )
114
114
return ffi .string (line )
115
115
116
- try :
117
- cert = bssl .SSL_get_peer_certificate (self ._connection )
118
- if cert == ffi .NULL :
119
- raise Exception ("get cert failed" )
116
+ with self ._lock :
117
+ if self ._connection :
118
+ try :
119
+ cert = bssl .SSL_get_peer_certificate (self ._connection )
120
+ if cert == ffi .NULL :
121
+ raise Exception ("get cert failed" )
120
122
121
- alt_names_p = bssl .get_alt_names (cert )
122
- if alt_names_p == ffi .NULL :
123
- raise Exception ("get alt_names failed" )
123
+ alt_names_p = bssl .get_alt_names (cert )
124
+ if alt_names_p == ffi .NULL :
125
+ raise Exception ("get alt_names failed" )
124
126
125
- alt_names = utils .to_str (ffi .string (alt_names_p ))
126
- bssl .free (alt_names_p )
127
+ alt_names = utils .to_str (ffi .string (alt_names_p ))
128
+ bssl .free (alt_names_p )
127
129
128
- subject = x509_name_to_string (bssl .X509_get_subject_name (cert ))
129
- issuer = x509_name_to_string (bssl .X509_get_issuer_name (cert ))
130
- altName = alt_names .split (";" )
131
- except Exception as e :
132
- subject = ""
133
- issuer = ""
134
- altName = []
130
+ subject = x509_name_to_string (bssl .X509_get_subject_name (cert ))
131
+ issuer = x509_name_to_string (bssl .X509_get_issuer_name (cert ))
132
+ altName = alt_names .split (";" )
133
+ except Exception as e :
134
+ subject = ""
135
+ issuer = ""
136
+ altName = []
135
137
136
138
self .peer_cert = {
137
139
"cert" : subject ,
@@ -143,40 +145,66 @@ def x509_name_to_string(xname):
143
145
return self .peer_cert
144
146
145
147
def send (self , data , flags = 0 ):
146
- try :
147
- ret = bssl .SSL_write (self ._connection , data , len (data ))
148
- return ret
149
- except Exception as e :
150
- self ._context .logger .exception ("ssl send:%r" , e )
151
- raise e
148
+ with self ._lock :
149
+ if not self ._connection :
150
+ e = socket .error (5 )
151
+ e .errno = 5
152
+ raise e
153
+
154
+ try :
155
+ ret = bssl .SSL_write (self ._connection , data , len (data ))
156
+ if ret <= 0 :
157
+ errno = bssl .SSL_get_error (self ._connection , ret )
158
+ self ._context .logger .warn ("send n:%d errno: %d ip:%s" , ret , errno , self .ip_str )
159
+ e = socket .error (2 )
160
+ e .errno = errno
161
+ raise e
162
+
163
+ return ret
164
+ except Exception as e :
165
+ self ._context .logger .exception ("ssl send:%r" , e )
166
+ raise e
152
167
153
168
def recv (self , bufsiz , flags = 0 ):
154
- buf = bytes (bufsiz )
155
- n = bssl .SSL_read (self ._connection , buf , bufsiz )
156
- if n <= 0 :
157
- errno = bssl .SSL_get_error (self ._connection , n )
158
- self ._context .logger .warn ("recv errno: %d ip:%s" , errno , self .ip_str )
159
- e = socket .error (2 )
160
- e .errno = errno
161
- raise e
162
-
163
- dat = buf [:n ]
164
- return dat
169
+ with self ._lock :
170
+ if not self ._connection :
171
+ e = socket .error (2 )
172
+ e .errno = 5
173
+ raise e
174
+
175
+ buf = bytes (bufsiz )
176
+ n = bssl .SSL_read (self ._connection , buf , bufsiz )
177
+ if n <= 0 :
178
+ errno = bssl .SSL_get_error (self ._connection , n )
179
+ self ._context .logger .warn ("recv n:%d errno: %d ip:%s" , n , errno , self .ip_str )
180
+ e = socket .error (2 )
181
+ e .errno = errno
182
+ raise e
183
+
184
+ dat = buf [:n ]
185
+ self ._context .logger .debug ("recv %d" , n )
186
+ return dat
165
187
166
188
def recv_into (self , buf , nbytes = None ):
167
- if not nbytes :
168
- nbytes = len (buf )
169
-
170
- b = ffi .from_buffer (buf )
171
- n = bssl .SSL_read (self ._connection , b , nbytes )
172
- if n <= 0 :
173
- errno = bssl .SSL_get_error (self ._connection , n )
174
- self ._context .logger .warn ("recv_into errno: %d ip:%s" , errno , self .ip_str )
175
- e = socket .error (2 )
176
- e .errno = errno
177
- raise e
178
-
179
- return n
189
+ with self ._lock :
190
+ if not self ._connection :
191
+ e = socket .error (2 )
192
+ e .errno = 5
193
+ raise e
194
+
195
+ if not nbytes :
196
+ nbytes = len (buf )
197
+ buf_new = bytes (nbytes )
198
+
199
+ n = bssl .SSL_read (self ._connection , buf_new , nbytes )
200
+ if n <= 0 :
201
+ errno = bssl .SSL_get_error (self ._connection , n )
202
+ e = socket .error (2 )
203
+ e .errno = errno
204
+ raise e
205
+
206
+ buf [:n ] = buf_new [:n ]
207
+ return n
180
208
181
209
def read (self , bufsiz , flags = 0 ):
182
210
return self .recv (bufsiz , flags )
@@ -185,27 +213,29 @@ def write(self, buf, flags=0):
185
213
return self .send (buf , flags )
186
214
187
215
def close (self ):
188
- if self ._makefile_refs < 1 :
216
+ with self ._lock :
189
217
self .running = False
190
218
if not self .socket_closed :
219
+ if self ._connection :
220
+ bssl .SSL_shutdown (self ._connection )
191
221
192
- bssl .SSL_shutdown (self ._connection )
193
- bssl .SSL_free (self ._connection )
194
- self ._connection = None
195
-
196
- self ._sock = None
197
222
self .socket_closed = True
198
223
if self ._on_close :
199
224
self ._on_close (self .ip_str )
200
- else :
201
- self ._makefile_refs -= 1
225
+
226
+ def __del__ (self ):
227
+ self .close ()
228
+ if self ._connection :
229
+ bssl .SSL_free (self ._connection )
230
+ self ._connection = None
231
+ self ._sock = None
202
232
203
233
def settimeout (self , t ):
204
234
if not self .running :
205
235
return
206
236
207
237
if self .timeout != t :
208
- # self._sock.settimeout(t)
238
+ self ._sock .settimeout (t )
209
239
self .timeout = t
210
240
211
241
def makefile (self , mode = 'r' , bufsize = - 1 ):
0 commit comments