Skip to content

Commit 4dcf79c

Browse files
authored
Merge pull request #3575 from heplesser/fix-3574
Fix data race in NodeManager
2 parents 83713a2 + 02537e7 commit 4dcf79c

File tree

9 files changed

+87
-100
lines changed

9 files changed

+87
-100
lines changed

nestkernel/connection_manager.cpp

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -443,6 +443,8 @@ nest::ConnectionManager::connect( NodeCollectionPTR sources,
443443
const DictionaryDatum& conn_spec,
444444
const std::vector< DictionaryDatum >& syn_specs )
445445
{
446+
kernel().node_manager.update_thread_local_node_data();
447+
446448
if ( sources->empty() )
447449
{
448450
throw IllegalConnection( "Presynaptic nodes cannot be an empty NodeCollection" );
@@ -486,32 +488,6 @@ nest::ConnectionManager::connect( NodeCollectionPTR sources,
486488
}
487489

488490

489-
void
490-
nest::ConnectionManager::connect( TokenArray sources, TokenArray targets, const DictionaryDatum& syn_spec )
491-
{
492-
// Get synapse id
493-
size_t syn_id = 0;
494-
auto synmodel = syn_spec->lookup( names::model );
495-
if ( not synmodel.empty() )
496-
{
497-
const std::string synmodel_name = getValue< std::string >( synmodel );
498-
// The following throws UnknownSynapseType for invalid synmodel_name
499-
syn_id = kernel().model_manager.get_synapse_model_id( synmodel_name );
500-
}
501-
// Connect all sources to all targets
502-
for ( auto&& source : sources )
503-
{
504-
auto source_node = kernel().node_manager.get_node_or_proxy( source );
505-
for ( auto&& target : targets )
506-
{
507-
auto target_node = kernel().node_manager.get_node_or_proxy( target );
508-
auto target_thread = target_node->get_thread();
509-
connect_( *source_node, *target_node, source, target_thread, syn_id, syn_spec );
510-
}
511-
}
512-
}
513-
514-
515491
void
516492
nest::ConnectionManager::update_delay_extrema_()
517493
{
@@ -645,6 +621,8 @@ nest::ConnectionManager::connect_arrays( long* sources,
645621
// only place, where stopwatch sw_construction_connect is needed in addition to nestmodule.cpp
646622
sw_construction_connect.start();
647623

624+
kernel().node_manager.update_thread_local_node_data();
625+
648626
// Mapping pointers to the first parameter value of each parameter to their respective names.
649627
// The bool indicates whether the value is an integer or not, and is determined at a later point.
650628
std::map< Name, std::pair< double*, bool > > param_pointers;
@@ -811,6 +789,8 @@ void
811789
nest::ConnectionManager::connect_sonata( const DictionaryDatum& graph_specs, const long hyberslab_size )
812790
{
813791
#ifdef HAVE_HDF5
792+
kernel().node_manager.update_thread_local_node_data();
793+
814794
SonataConnector sonata_connector( graph_specs, hyberslab_size );
815795

816796
// Set flag before calling sonata_connector.connect() in case exception is thrown after some connections have been
@@ -864,6 +844,8 @@ nest::ConnectionManager::connect_tripartite( NodeCollectionPTR sources,
864844
const std::string primary_rule = static_cast< const std::string >( ( *conn_spec )[ names::rule ] );
865845
const std::string third_rule = static_cast< const std::string >( ( *third_conn_spec )[ names::rule ] );
866846

847+
kernel().node_manager.update_thread_local_node_data();
848+
867849
ConnBuilder cb( primary_rule, third_rule, sources, targets, third, conn_spec, third_conn_spec, syn_specs );
868850

869851
// at this point, all entries in conn_spec and syn_spec have been checked

nestkernel/connection_manager.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -125,8 +125,6 @@ class ConnectionManager : public ManagerInterface
125125
const DictionaryDatum& conn_spec,
126126
const std::vector< DictionaryDatum >& syn_specs );
127127

128-
void connect( TokenArray sources, TokenArray targets, const DictionaryDatum& syn_spec );
129-
130128
/**
131129
* Connect two nodes.
132130
*

nestkernel/nest.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,9 @@ get_connections( const DictionaryDatum& dict )
221221
void
222222
disconnect( const ArrayDatum& conns )
223223
{
224+
// probably not strictly necessary here, but does nothing if all is up to date
225+
kernel().node_manager.update_thread_local_node_data();
226+
224227
for ( size_t conn_index = 0; conn_index < conns.size(); ++conn_index )
225228
{
226229
const auto conn_datum = getValue< ConnectionDatum >( conns.get( conn_index ) );

nestkernel/node_manager.cpp

Lines changed: 43 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ NodeManager::NodeManager()
5050
, node_collection_container_()
5151
, wfr_nodes_vec_()
5252
, wfr_is_used_( false )
53-
, wfr_network_size_( 0 ) // zero to force update
53+
, size_last_local_data_update_( 0 ) // zero to force update
5454
, num_active_nodes_( 0 )
5555
, num_thread_local_devices_()
5656
, have_nodes_changed_( true )
@@ -68,11 +68,11 @@ NodeManager::~NodeManager()
6868
void
6969
NodeManager::initialize( const bool adjust_number_of_threads_or_rng_only )
7070
{
71-
// explicitly force construction of wfr_nodes_vec_ to ensure consistent state
72-
wfr_network_size_ = 0;
71+
// explicitly force construction of thread-local node data to ensure consistent state
72+
size_last_local_data_update_ = 0;
7373
local_nodes_.resize( kernel().vp_manager.get_num_threads() );
7474
num_thread_local_devices_.resize( kernel().vp_manager.get_num_threads(), 0 );
75-
ensure_valid_thread_local_ids();
75+
update_thread_local_node_data();
7676

7777
if ( not adjust_number_of_threads_or_rng_only )
7878
{
@@ -523,74 +523,56 @@ NodeManager::get_thread_siblings( size_t node_id ) const
523523
}
524524

525525
void
526-
NodeManager::ensure_valid_thread_local_ids()
526+
NodeManager::update_thread_local_node_data()
527527
{
528-
// Check if the network size changed, in order to not enter
529-
// the critical region if it is not necessary. Note that this
530-
// test also covers that case that nodes have been deleted
531-
// by reset.
532-
if ( size() == wfr_network_size_ )
528+
kernel().vp_manager.assert_single_threaded();
529+
530+
if ( thread_local_data_is_up_to_date() )
533531
{
534532
return;
535533
}
536534

537-
#pragma omp critical( update_wfr_nodes_vec )
538-
{
539-
// This code may be called from a thread-parallel context, when it is
540-
// invoked by TargetIdentifierIndex::set_target() during parallel
541-
// wiring. Nested OpenMP parallelism is problematic, therefore, we
542-
// enforce single threading here. This should be unproblematic wrt
543-
// performance, because the wfr_nodes_vec_ is rebuilt only once after
544-
// changes in network size.
545-
//
546-
// Check again, if the network size changed, since a previous thread
547-
// can have updated wfr_nodes_vec_ before.
548-
if ( size() != wfr_network_size_ )
549-
{
550-
551-
// We clear the existing wfr_nodes_vec_ and then rebuild it.
552-
wfr_nodes_vec_.clear();
553-
wfr_nodes_vec_.resize( kernel().vp_manager.get_num_threads() );
554-
555-
for ( size_t tid = 0; tid < kernel().vp_manager.get_num_threads(); ++tid )
556-
{
557-
wfr_nodes_vec_[ tid ].clear();
558-
559-
const size_t num_thread_local_wfr_nodes = std::count_if( local_nodes_[ tid ].begin(),
560-
local_nodes_[ tid ].end(),
561-
[]( const SparseNodeArray::NodeEntry& elem ) { return elem.get_node()->node_uses_wfr_; } );
562-
wfr_nodes_vec_[ tid ].reserve( num_thread_local_wfr_nodes );
535+
// We clear the existing wfr_nodes_vec_ and then rebuild it.
536+
wfr_nodes_vec_.clear();
537+
wfr_nodes_vec_.resize( kernel().vp_manager.get_num_threads() );
563538

564-
auto node_it = local_nodes_[ tid ].begin();
565-
size_t idx = 0;
566-
for ( ; node_it < local_nodes_[ tid ].end(); ++node_it, ++idx )
567-
{
568-
auto node = node_it->get_node();
569-
node->set_thread_lid( idx );
570-
if ( node->node_uses_wfr_ )
571-
{
572-
wfr_nodes_vec_[ tid ].push_back( node );
573-
}
574-
}
575-
} // end of for threads
539+
for ( size_t tid = 0; tid < kernel().vp_manager.get_num_threads(); ++tid )
540+
{
541+
wfr_nodes_vec_[ tid ].clear();
576542

577-
wfr_network_size_ = size();
543+
const size_t num_thread_local_wfr_nodes = std::count_if( local_nodes_[ tid ].begin(),
544+
local_nodes_[ tid ].end(),
545+
[]( const SparseNodeArray::NodeEntry& elem ) { return elem.get_node()->node_uses_wfr_; } );
546+
wfr_nodes_vec_[ tid ].reserve( num_thread_local_wfr_nodes );
578547

579-
// wfr_is_used_ indicates, whether at least one
580-
// of the threads has a neuron that uses waveform relaxation
581-
// all threads then need to perform a wfr_update
582-
// step, because gather_events() has to be done in an
583-
// openmp single section
584-
wfr_is_used_ = false;
585-
for ( size_t tid = 0; tid < kernel().vp_manager.get_num_threads(); ++tid )
548+
auto node_it = local_nodes_[ tid ].begin();
549+
size_t idx = 0;
550+
for ( ; node_it < local_nodes_[ tid ].end(); ++node_it, ++idx )
551+
{
552+
auto node = node_it->get_node();
553+
node->set_thread_lid( idx );
554+
if ( node->node_uses_wfr_ )
586555
{
587-
if ( wfr_nodes_vec_[ tid ].size() > 0 )
588-
{
589-
wfr_is_used_ = true;
590-
}
556+
wfr_nodes_vec_[ tid ].push_back( node );
591557
}
592558
}
593-
} // omp critical
559+
} // end of for threads
560+
561+
size_last_local_data_update_ = size();
562+
563+
// wfr_is_used_ indicates, whether at least one
564+
// of the threads has a neuron that uses waveform relaxation
565+
// all threads then need to perform a wfr_update
566+
// step, because gather_events() has to be done in an
567+
// openmp single section
568+
wfr_is_used_ = false;
569+
for ( size_t tid = 0; tid < kernel().vp_manager.get_num_threads(); ++tid )
570+
{
571+
if ( wfr_nodes_vec_[ tid ].size() > 0 )
572+
{
573+
wfr_is_used_ = true;
574+
}
575+
}
594576
}
595577

596578
void

nestkernel/node_manager.h

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -175,13 +175,23 @@ class NodeManager : public ManagerInterface
175175
std::vector< Node* > get_thread_siblings( size_t n ) const;
176176

177177
/**
178-
* Ensure that all nodes in the network have valid thread-local IDs.
178+
* Rebuild per-thread vectors of local nodes using WFR and set thread-local IDs on all local nodes.
179179
*
180-
* Create up-to-date vector of local nodes, nodes_vec_.
181-
* This method also sets the thread-local ID on all local nodes.
180+
* @note This method must be called from a serial context before connection creation or simulation.
182181
*/
183-
void ensure_valid_thread_local_ids();
182+
void update_thread_local_node_data();
184183

184+
/**
185+
* Return true if thread-local data structures and thread-local node IDs are up to date.
186+
*
187+
* @note The decision is based on whether new nodes have been created since update_thread_local_node_data()
188+
* was run last.
189+
*/
190+
bool thread_local_data_is_up_to_date() const;
191+
192+
/**
193+
* Return node on thread t with given local node id.
194+
*/
185195
Node* thread_lid_to_node( size_t t, targetindex thread_local_id ) const;
186196

187197
/**
@@ -343,9 +353,9 @@ class NodeManager : public ManagerInterface
343353
//!< use the waveform relaxation method
344354
bool wfr_is_used_; //!< there is at least one node that uses
345355
//!< waveform relaxation
346-
//! Network size when wfr_nodes_vec_ was last updated
347-
size_t wfr_network_size_;
348-
size_t num_active_nodes_; //!< number of nodes created by prepare_nodes
356+
357+
size_t size_last_local_data_update_; //!< Network size when local node data was last updated
358+
size_t num_active_nodes_; //!< number of nodes created by prepare_nodes
349359

350360
std::vector< size_t > num_thread_local_devices_; //!< stores number of thread local devices
351361

@@ -401,6 +411,15 @@ NodeManager::set_have_nodes_changed( const bool changed )
401411
have_nodes_changed_ = changed;
402412
}
403413

414+
inline bool
415+
NodeManager::thread_local_data_is_up_to_date() const
416+
{
417+
// Our logic assumes that we never delete nodes from a network
418+
assert( size() >= size_last_local_data_update_ );
419+
420+
return size() == size_last_local_data_update_;
421+
}
422+
404423
} // namespace
405424

406425
#endif /* NODE_MANAGER_H */

nestkernel/simulation_manager.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -526,7 +526,7 @@ nest::SimulationManager::prepare()
526526
kernel().event_delivery_manager.configure_spike_data_buffers();
527527
}
528528

529-
kernel().node_manager.ensure_valid_thread_local_ids();
529+
kernel().node_manager.update_thread_local_node_data();
530530
kernel().node_manager.prepare_nodes();
531531

532532
// we have to do enter_runtime after prepare_nodes, since we use

nestkernel/sp_manager.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -242,6 +242,9 @@ SPManager::disconnect( NodeCollectionPTR sources,
242242
DictionaryDatum& conn_spec,
243243
DictionaryDatum& syn_spec )
244244
{
245+
// probably not strictly necessary here, but does nothing if all is up to date
246+
kernel().node_manager.update_thread_local_node_data();
247+
245248
if ( kernel().connection_manager.connections_have_changed() )
246249
{
247250
#pragma omp parallel

nestkernel/spatial.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -406,6 +406,8 @@ connect_layers( NodeCollectionPTR source_nc, NodeCollectionPTR target_nc, const
406406
ConnectionCreator connector( connection_dict );
407407
ALL_ENTRIES_ACCESSED( *connection_dict, "nest::CreateLayers", "Unread dictionary entries: " );
408408

409+
kernel().node_manager.update_thread_local_node_data();
410+
409411
// Set flag before calling source->connect() in case exception is thrown after some connections have been created.
410412
kernel().connection_manager.set_connections_have_changed();
411413

nestkernel/target_identifier.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,11 +118,9 @@ class TargetIdentifierIndex
118118
{
119119
}
120120

121-
122121
TargetIdentifierIndex( const TargetIdentifierIndex& t ) = default;
123122
TargetIdentifierIndex& operator=( const TargetIdentifierIndex& t ) = default;
124123

125-
126124
void
127125
get_status( DictionaryDatum& d ) const
128126
{
@@ -168,7 +166,7 @@ class TargetIdentifierIndex
168166
inline void
169167
TargetIdentifierIndex::set_target( Node* target )
170168
{
171-
kernel().node_manager.ensure_valid_thread_local_ids();
169+
assert( kernel().node_manager.thread_local_data_is_up_to_date() );
172170

173171
size_t target_lid = target->get_thread_lid();
174172
if ( target_lid > max_targetindex )

0 commit comments

Comments
 (0)