Parallelize ManyToMany plugin

This commit is contained in:
Michael Krasnyk 2017-08-29 20:02:04 +02:00
parent 543048efcc
commit 172a8bdcdb

View File

@ -26,13 +26,12 @@ struct NodeBucket
EdgeDuration duration; EdgeDuration duration;
NodeBucket(NodeID middle_node, unsigned column_index, EdgeWeight weight, EdgeDuration duration) NodeBucket(NodeID middle_node, unsigned column_index, EdgeWeight weight, EdgeDuration duration)
: middle_node(middle_node), column_index(column_index), weight(weight), duration(duration) {} : middle_node(middle_node), column_index(column_index), weight(weight), duration(duration)
{
}
// partial order comparison // partial order comparison
bool operator<(const NodeBucket& rhs) const bool operator<(const NodeBucket &rhs) const { return middle_node < rhs.middle_node; }
{
return middle_node < rhs.middle_node;
}
// functor for equal_range // functor for equal_range
struct Compare struct Compare
@ -257,7 +256,10 @@ void forwardRoutingStep(const DataFacade<Algorithm> &facade,
const auto source_duration = query_heap.GetData(node).duration; const auto source_duration = query_heap.GetData(node).duration;
// check if each encountered node has an entry // check if each encountered node has an entry
const auto &bucket_list = std::equal_range(search_space_with_buckets.begin(), search_space_with_buckets.end(), node, NodeBucket::Compare()); const auto &bucket_list = std::equal_range(search_space_with_buckets.begin(),
search_space_with_buckets.end(),
node,
NodeBucket::Compare());
for (const auto &current_bucket : boost::make_iterator_range(bucket_list)) for (const auto &current_bucket : boost::make_iterator_range(bucket_list))
{ {
// get target id from bucket entry // get target id from bucket entry
@ -335,37 +337,55 @@ std::vector<EdgeWeight> manyToManySearch(SearchEngineData<Algorithm> &engine_wor
std::vector<EdgeWeight> weights_table(number_of_entries, INVALID_EDGE_WEIGHT); std::vector<EdgeWeight> weights_table(number_of_entries, INVALID_EDGE_WEIGHT);
std::vector<EdgeDuration> durations_table(number_of_entries, MAXIMAL_EDGE_DURATION); std::vector<EdgeDuration> durations_table(number_of_entries, MAXIMAL_EDGE_DURATION);
std::mutex lock;
std::vector<NodeBucket> search_space_with_buckets; std::vector<NodeBucket> search_space_with_buckets;
engine_working_data.InitializeOrClearManyToManyThreadLocalStorage(facade.GetNumberOfNodes());
auto &query_heap = *(engine_working_data.many_to_many_heap);
// Backward search for target phantoms // Backward search for target phantoms
for (const auto column_idx : util::irange<std::size_t>(0u, target_indices.size())) tbb::parallel_for(
tbb::blocked_range<std::size_t>{0, target_indices.size()},
[&](const auto &chunk) {
for (auto column_idx = chunk.begin(), end = chunk.end(); column_idx != end;
++column_idx)
{ {
const auto index = target_indices[column_idx]; const auto index = target_indices[column_idx];
const auto &phantom = phantom_nodes[index]; const auto &phantom = phantom_nodes[index];
query_heap.Clear(); engine_working_data.InitializeOrClearManyToManyThreadLocalStorage(
facade.GetNumberOfNodes());
auto &query_heap = *(engine_working_data.many_to_many_heap);
insertTargetInHeap(query_heap, phantom); insertTargetInHeap(query_heap, phantom);
// explore search space // explore search space
std::vector<NodeBucket> local_buckets;
while (!query_heap.Empty()) while (!query_heap.Empty())
{ {
backwardRoutingStep(facade, column_idx, query_heap, search_space_with_buckets, phantom); backwardRoutingStep(facade, column_idx, query_heap, local_buckets, phantom);
}
} }
std::sort(search_space_with_buckets.begin(), search_space_with_buckets.end()); { // Insert local buckets into the global search space
std::lock_guard<std::mutex> guard{lock};
search_space_with_buckets.insert(std::end(search_space_with_buckets),
std::begin(local_buckets),
std::end(local_buckets));
}
}
});
tbb::parallel_sort(search_space_with_buckets.begin(), search_space_with_buckets.end());
// For each source do forward search // For each source do forward search
for (const auto row_idx : util::irange<std::size_t>(0, source_indices.size())) tbb::parallel_for(tbb::blocked_range<std::size_t>{0, source_indices.size()},
[&](const auto &chunk) {
for (auto row_idx = chunk.begin(), end = chunk.end(); row_idx != end;
++row_idx)
{ {
const auto index = source_indices[row_idx]; const auto index = source_indices[row_idx];
const auto &phantom = phantom_nodes[index]; const auto &phantom = phantom_nodes[index];
// clear heap and insert source nodes // clear heap and insert source nodes
query_heap.Clear(); engine_working_data.InitializeOrClearManyToManyThreadLocalStorage(
facade.GetNumberOfNodes());
auto &query_heap = *(engine_working_data.many_to_many_heap);
insertSourceInHeap(query_heap, phantom); insertSourceInHeap(query_heap, phantom);
// explore search space // explore search space
@ -381,6 +401,7 @@ std::vector<EdgeWeight> manyToManySearch(SearchEngineData<Algorithm> &engine_wor
phantom); phantom);
} }
} }
});
return durations_table; return durations_table;
} }