@@ -105,7 +105,7 @@ def save(self, bulk=False, id=None, parent=None, force=False):
         if force:
             version = None
         res = conn.index(self,
-                         meta.index, meta.type, id, parent=parent, bulk=bulk, version=version, force_insert=force)
+            meta.index, meta.type, id, parent=parent, bulk=bulk, version=version, force_insert=force)
         if not bulk:
             self._meta.id = res._id
             self._meta.version = res._version
@@ -229,6 +229,38 @@ def dict_to_object(self, d):
         return DotDict(d)


+class Bulker(object):
+    def __init__(self, conn, bulk_size=400, raise_on_bulk_item_failure=False):
+        self.conn = conn
+        self.bulk_size = bulk_size
+        # protects bulk_data
+        self.bulk_lock = threading.RLock()
+        with self.bulk_lock:
+            self.bulk_data = []
+        self.raise_on_bulk_item_failure = raise_on_bulk_item_failure
+
+    def add_to_bulk_queue(self, content):
+        with self.bulk_lock:
+            self.bulk_data.append(content)
+
+    def flush_bulk(self, forced=False):
+        with self.bulk_lock:
+            if forced or len(self.bulk_data) >= self.bulk_size:
+                batch = self.bulk_data
+                self.bulk_data = []
+            else:
+                return None
+
+        if len(batch) > 0:
+            bulk_result = self.conn._send_request("POST",
+                                                  "/_bulk",
+                                                  "\n".join(batch) + "\n")
+
+            if self.raise_on_bulk_item_failure:
+                _raise_exception_if_bulk_item_failed(bulk_result)
+
+            return bulk_result
+
 class ES(object):
     """
     ES connection object.
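
For orientation, the extracted Bulker simply accumulates pre-serialized bulk lines under a lock and POSTs them to /_bulk once the threshold is hit or a flush is forced. A rough usage sketch against a running node; the index/type names and document are made up, the import assumes Bulker sits next to ES in es.py as in this patch, and normally ES.index(..., bulk=True) drives this for you:

    import json
    from pyes.es import ES, Bulker  # assumes Bulker lives alongside ES in es.py, as in this patch

    es = ES("localhost:9200")
    bulker = Bulker(es, bulk_size=2)  # the ES instance plays the role of `conn`
    action = json.dumps({"index": {"_index": "test-index", "_type": "test-type"}})
    source = json.dumps({"name": "Joe Tester"})
    bulker.add_to_bulk_queue("%s\n%s" % (action, source))
    bulker.flush_bulk(forced=True)  # POSTs the queued lines to /_bulk
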
@@ -237,7 +269,7 @@ class ES(object):
     encoder = ESJsonEncoder
     decoder = ESJsonDecoder

-    def __init__(self, server="localhost:9200", timeout=5.0, bulk_size=400,
+    def __init__(self, server="localhost:9200", timeout=30.0, bulk_size=400,
                  encoder=None, decoder=None,
                  max_retries=3,
                  default_indices=['_all'],
@@ -246,7 +278,8 @@ def __init__(self, server="localhost:9200", timeout=5.0, bulk_size=400,
                  model=ElasticSearchModel,
                  basic_auth=None,
                  raise_on_bulk_item_failure=False,
-                 document_object_field=None):
+                 document_object_field=None,
+                 bulker_class=Bulker):
         """
         Init a es object.
         Servers can be defined in different forms:
@@ -286,6 +319,7 @@ def __init__(self, server="localhost:9200", timeout=5.0, bulk_size=400,
         self.connection = None
         self._mappings = None
         self.document_object_field = document_object_field
+        self.bulker_class = bulker_class

         if model is None:
             model = lambda connection, model: model
@@ -303,11 +337,7 @@ def __init__(self, server="localhost:9200", timeout=5.0, bulk_size=400,

         #used in bulk
         self.bulk_size = bulk_size  #size of the bulk
-        # protects bulk_data
-        self.bulk_lock = threading.RLock()
-        with self.bulk_lock:
-            self.bulk_data = []
-        self.raise_on_bulk_item_failure = raise_on_bulk_item_failure
+        self.bulker = bulker_class(self, bulk_size=bulk_size, raise_on_bulk_item_failure=raise_on_bulk_item_failure)

         if encoder:
             self.encoder = encoder
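
Because the queue now lives behind bulker_class, a caller can swap in custom batching behavior without touching ES itself. A hedged sketch; the subclass and its logging are purely illustrative, and the import again assumes both classes live in es.py as in this patch:

    import logging
    from pyes.es import ES, Bulker  # assumes both classes live in es.py, as in this patch

    class LoggingBulker(Bulker):
        """Hypothetical subclass: log each batch that is actually sent."""
        def flush_bulk(self, forced=False):
            result = super(LoggingBulker, self).flush_bulk(forced)
            if result is not None:
                logging.getLogger("pyes").info("flushed one bulk batch")
            return result

    es = ES("localhost:9200", bulk_size=1000, bulker_class=LoggingBulker)
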
@@ -333,16 +363,16 @@ def __del__(self):
         Destructor
         """
         # Don't bother getting the lock
-        if len(self.bulk_data) > 0:
+        if self.bulker:
             # It's not safe to rely on the destructor to flush the queue:
             # the Python documentation explicitly states "It is not guaranteed
             # that __del__() methods are called for objects that still exist
             # when the interpreter exits."
             logger.error("pyes object %s is being destroyed, but bulk "
                          "operations have not been flushed. Call force_bulk()!",
-                         self)
+                self)
             # Do our best to save the client anyway...
-            self.force_bulk()
+            self.bulker.flush_bulk(forced=True)

     def _check_servers(self):
         """Check the servers variable and convert in a valid tuple form"""
@@ -405,12 +435,17 @@ def _init_connection(self):
         if _type in ["http", "https"]:
             self.connection = http_connect(
                 [(_type, host, port) for _type, host, port in self.servers if _type in ["http", "https"]],
-                timeout=self.timeout, basic_auth=self.basic_auth,
-                max_retries=self.max_retries)
+                timeout=self.timeout
+                ,
+                basic_auth=self.basic_auth
+                ,
+                max_retries=self.max_retries)
             return
         elif _type == "thrift":
             self.connection = thrift_connect([(host, port) for _type, host, port in self.servers if _type == "thrift"],
-                                              timeout=self.timeout, max_retries=self.max_retries)
+                timeout=self.timeout
+                ,
+                max_retries=self.max_retries)

     def _discovery(self):
         """
@@ -444,7 +479,7 @@ def _send_request(self, method, path, body=None, params=None, headers=None, raw=
         else:
             body = ""
         request = RestRequest(method=Method._NAMES_TO_VALUES[method.upper()],
-                              uri=path, parameters=params, headers=headers, body=body)
+            uri=path, parameters=params, headers=headers, body=body)
         if self.dump_curl is not None:
             self._dump_curl_request(request)

@@ -536,8 +571,8 @@ def _set_default_indices(self, default_indices):
     @property
     def mappings(self):
         if self._mappings is None:
-            self._mappings = Mapper(self.get_mapping(["_all"]), connection=self,
-                                    document_object_field=self.document_object_field)
+            self._mappings = Mapper(self.get_mapping(indices=["_all"]), connection=self,
+                document_object_field=self.document_object_field)
         return self._mappings

     #---- Admin commands
@@ -799,7 +834,7 @@ def optimize(self, indices=None,
             only_expunge_deletes=only_expunge_deletes,
             refresh=refresh,
             flush=flush,
-            )
+        )
         if max_num_segments is not None:
             params['max_num_segments'] = max_num_segments
         result = self._send_request('POST', path, params=params)
@@ -974,17 +1009,14 @@ def cluster_stats(self, nodes=None):
         path = self._make_path(parts)
         return self._send_request('GET', path)

-    def _add_to_bulk_queue(self, content):
-        with self.bulk_lock:
-            self.bulk_data.append(content)

     def index_raw_bulk(self, header, document):
         """
         Function helper for fast inserting

         header and document must be string "\n" ended
         """
-        self._add_to_bulk_queue(u"%s%s" % (header, document))
+        self.bulker.add_to_bulk_queue(u"%s%s" % (header, document))
         return self.flush_bulk()

     def index(self, doc, index, doc_type, id=None, parent=None,
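
index_raw_bulk keeps its old contract: the caller builds the newline-terminated action/source pair itself, and pyes only queues it. Roughly, with made-up index, type, id, and field names:

    import json
    from pyes import ES

    es = ES("localhost:9200")
    # hypothetical index/type/id; both strings must end with "\n", as the docstring requires
    header = json.dumps({"index": {"_index": "test-index", "_type": "test-type", "_id": "1"}}) + "\n"
    document = json.dumps({"title": "Hello"}) + "\n"
    es.index_raw_bulk(header, document)  # queued on es.bulker, sent once bulk_size is reached
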
@@ -1018,7 +1050,7 @@ def index(self, doc, index, doc_type, id=None, parent=None,
             if isinstance(doc, dict):
                 doc = json.dumps(doc, cls=self.encoder)
             command = "%s\n%s" % (json.dumps(cmd, cls=self.encoder), doc)
-            self._add_to_bulk_queue(command)
+            self.bulker.add_to_bulk_queue(command)
             return self.flush_bulk()

         if force_insert:
@@ -1063,22 +1095,7 @@ def flush_bulk(self, forced=False):
         """
         Send pending operations if forced or if the bulk threshold is exceeded.
         """
-        with self.bulk_lock:
-            if forced or len(self.bulk_data) >= self.bulk_size:
-                batch = self.bulk_data
-                self.bulk_data = []
-            else:
-                return None
-
-        if len(batch) > 0:
-            bulk_result = self._send_request("POST",
-                                             "/_bulk",
-                                             "\n".join(batch) + "\n")
-
-            if self.raise_on_bulk_item_failure:
-                _raise_exception_if_bulk_item_failed(bulk_result)
-
-            return bulk_result
+        return self.bulker.flush_bulk(forced)

     def force_bulk(self):
         """
@@ -1139,7 +1156,7 @@ def update(self, extra_doc, index, doc_type, id, querystring_args=None,
                 new_doc = current_doc
             try:
                 return self.index(new_doc, index, doc_type, id,
-                                  version=current_doc._meta.version, querystring_args=querystring_args)
+                    version=current_doc._meta.version, querystring_args=querystring_args)
             except VersionConflictEngineException:
                 if attempt <= 0:
                     raise
@@ -1154,7 +1171,7 @@ def delete(self, index, doc_type, id, bulk=False, querystring_args=None):
         if bulk:
             cmd = {"delete": {"_index": index, "_type": doc_type,
                               "_id": id}}
-            self._add_to_bulk_queue(json.dumps(cmd, cls=self.encoder))
+            self.bulker.add_to_bulk_queue(json.dumps(cmd, cls=self.encoder))
             return self.flush_bulk()

         path = self._make_path([index, doc_type, id])
@@ -1258,8 +1275,8 @@ def mget(self, ids, index=None, doc_type=None, routing=None, **get_params):
         if routing:
             get_params["routing"] = routing
         results = self._send_request('GET', "/_mget",
-                                     body={'docs': body},
-                                     params=get_params)
+            body={'docs': body},
+            params=get_params)
         if 'docs' in results:
             model = self.model
             return [model(self, item) for item in results['docs']]
@@ -1369,7 +1386,7 @@ def count(self, query=None, indices=None, doc_types=None, **query_params):
         if doc_types is None:
             doc_types = []
         if query is None:
-            from ..query import MatchAllQuery
+            from .query import MatchAllQuery

            query = MatchAllQuery()
        if hasattr(query, 'to_query_json'):
@@ -1542,7 +1559,7 @@ def _do_search(self, auto_increment=False):
             self.query.size = self.chuck_size

         self._results = self.connection.search_raw(self.query, indices=self.indices,
-                                                    doc_types=self.doc_types, **self.query_params)
+            doc_types=self.doc_types, **self.query_params)
         if 'search_type' in self.query_params and self.query_params['search_type'] == "scan":
             self.scroller_parameters['search_type'] = self.query_params['search_type']
             del self.query_params['search_type']
@@ -1561,7 +1578,7 @@ def _do_search(self, auto_increment=False):
         else:
             try:
                 self._results = self.connection.search_scroll(self.scroller_id,
-                                                              self.scroller_parameters.get("scroll", "10m"))
+                    self.scroller_parameters.get("scroll", "10m"))
                 self.scroller_id = self._results['_scroll_id']
             except ReduceSearchPhaseException:
                 #bad hack, should be not hits on the last iteration
@@ -1670,7 +1687,7 @@ def get_start_end(val):
         query['size'] = end - start

         results = self.connection.search_raw(query, indices=self.indices,
-                                              doc_types=self.doc_types, **self.query_params)
+            doc_types=self.doc_types, **self.query_params)

         hits = results['hits']['hits']
         if not isinstance(val, slice):