Parallelize ManyToMany plugin

This commit is contained in:
Michael Krasnyk 2017-08-29 20:02:04 +02:00
parent 543048efcc
commit 172a8bdcdb

View File

@ -25,26 +25,25 @@ struct NodeBucket
EdgeWeight weight; EdgeWeight weight;
EdgeDuration duration; EdgeDuration duration;
NodeBucket(NodeID middle_node, unsigned column_index, EdgeWeight weight, EdgeDuration duration) NodeBucket(NodeID middle_node, unsigned column_index, EdgeWeight weight, EdgeDuration duration)
: middle_node(middle_node), column_index(column_index), weight(weight), duration(duration) {} : middle_node(middle_node), column_index(column_index), weight(weight), duration(duration)
{
}
// partial order comparison // partial order comparison
bool operator<(const NodeBucket& rhs) const bool operator<(const NodeBucket &rhs) const { return middle_node < rhs.middle_node; }
{
return middle_node < rhs.middle_node;
}
// functor for equal_range // functor for equal_range
struct Compare struct Compare
{ {
bool operator() (const NodeBucket &lhs, const NodeID& rhs) const bool operator()(const NodeBucket &lhs, const NodeID &rhs) const
{ {
return lhs.middle_node < rhs; return lhs.middle_node < rhs;
} }
bool operator() (const NodeID &lhs, const NodeBucket& rhs) const bool operator()(const NodeID &lhs, const NodeBucket &rhs) const
{ {
return lhs < rhs.middle_node; return lhs < rhs.middle_node;
} }
}; };
}; };
@ -257,7 +256,10 @@ void forwardRoutingStep(const DataFacade<Algorithm> &facade,
const auto source_duration = query_heap.GetData(node).duration; const auto source_duration = query_heap.GetData(node).duration;
// check if each encountered node has an entry // check if each encountered node has an entry
const auto &bucket_list = std::equal_range(search_space_with_buckets.begin(), search_space_with_buckets.end(), node, NodeBucket::Compare()); const auto &bucket_list = std::equal_range(search_space_with_buckets.begin(),
search_space_with_buckets.end(),
node,
NodeBucket::Compare());
for (const auto &current_bucket : boost::make_iterator_range(bucket_list)) for (const auto &current_bucket : boost::make_iterator_range(bucket_list))
{ {
// get target id from bucket entry // get target id from bucket entry
@ -335,52 +337,71 @@ std::vector<EdgeWeight> manyToManySearch(SearchEngineData<Algorithm> &engine_wor
std::vector<EdgeWeight> weights_table(number_of_entries, INVALID_EDGE_WEIGHT); std::vector<EdgeWeight> weights_table(number_of_entries, INVALID_EDGE_WEIGHT);
std::vector<EdgeDuration> durations_table(number_of_entries, MAXIMAL_EDGE_DURATION); std::vector<EdgeDuration> durations_table(number_of_entries, MAXIMAL_EDGE_DURATION);
std::mutex lock;
std::vector<NodeBucket> search_space_with_buckets; std::vector<NodeBucket> search_space_with_buckets;
engine_working_data.InitializeOrClearManyToManyThreadLocalStorage(facade.GetNumberOfNodes());
auto &query_heap = *(engine_working_data.many_to_many_heap);
// Backward search for target phantoms // Backward search for target phantoms
for (const auto column_idx : util::irange<std::size_t>(0u, target_indices.size())) tbb::parallel_for(
{ tbb::blocked_range<std::size_t>{0, target_indices.size()},
const auto index = target_indices[column_idx]; [&](const auto &chunk) {
const auto &phantom = phantom_nodes[index]; for (auto column_idx = chunk.begin(), end = chunk.end(); column_idx != end;
++column_idx)
{
const auto index = target_indices[column_idx];
const auto &phantom = phantom_nodes[index];
query_heap.Clear(); engine_working_data.InitializeOrClearManyToManyThreadLocalStorage(
insertTargetInHeap(query_heap, phantom); facade.GetNumberOfNodes());
auto &query_heap = *(engine_working_data.many_to_many_heap);
insertTargetInHeap(query_heap, phantom);
// explore search space // explore search space
while (!query_heap.Empty()) std::vector<NodeBucket> local_buckets;
{ while (!query_heap.Empty())
backwardRoutingStep(facade, column_idx, query_heap, search_space_with_buckets, phantom); {
} backwardRoutingStep(facade, column_idx, query_heap, local_buckets, phantom);
} }
std::sort(search_space_with_buckets.begin(), search_space_with_buckets.end()); { // Insert local buckets into the global search space
std::lock_guard<std::mutex> guard{lock};
search_space_with_buckets.insert(std::end(search_space_with_buckets),
std::begin(local_buckets),
std::end(local_buckets));
}
}
});
tbb::parallel_sort(search_space_with_buckets.begin(), search_space_with_buckets.end());
// For each source do forward search // For each source do forward search
for (const auto row_idx : util::irange<std::size_t>(0, source_indices.size())) tbb::parallel_for(tbb::blocked_range<std::size_t>{0, source_indices.size()},
{ [&](const auto &chunk) {
const auto index = source_indices[row_idx]; for (auto row_idx = chunk.begin(), end = chunk.end(); row_idx != end;
const auto &phantom = phantom_nodes[index]; ++row_idx)
{
const auto index = source_indices[row_idx];
const auto &phantom = phantom_nodes[index];
// clear heap and insert source nodes // clear heap and insert source nodes
query_heap.Clear(); engine_working_data.InitializeOrClearManyToManyThreadLocalStorage(
insertSourceInHeap(query_heap, phantom); facade.GetNumberOfNodes());
auto &query_heap = *(engine_working_data.many_to_many_heap);
insertSourceInHeap(query_heap, phantom);
// explore search space // explore search space
while (!query_heap.Empty()) while (!query_heap.Empty())
{ {
forwardRoutingStep(facade, forwardRoutingStep(facade,
row_idx, row_idx,
number_of_targets, number_of_targets,
query_heap, query_heap,
search_space_with_buckets, search_space_with_buckets,
weights_table, weights_table,
durations_table, durations_table,
phantom); phantom);
} }
} }
});
return durations_table; return durations_table;
} }