diff --git a/src/contractor/contractor.cpp b/src/contractor/contractor.cpp index 83556af73..0a81f83e1 100644 --- a/src/contractor/contractor.cpp +++ b/src/contractor/contractor.cpp @@ -210,6 +210,10 @@ namespace struct Segment final { OSMNodeID from, to; + bool operator==(const Segment &other) const + { + return std::tie(from, to) == std::tie(other.from, other.to); + } }; struct SpeedSource final @@ -222,36 +226,56 @@ struct SegmentSpeedSource final { Segment segment; SpeedSource speed_source; + // < operator is overloaded here to return a > comparison to be used by the + // std::lower_bound() call in the find() function + bool operator<(const SegmentSpeedSource &other) const + { + return std::tie(segment.from, segment.to) > std::tie(other.segment.from, other.segment.to); + } }; +struct Turn final +{ + OSMNodeID from, via, to; + bool operator==(const Turn &other) const + { + return std::tie(from, via, to) == std::tie(other.from, other.via, other.to); + } +}; + +struct PenaltySource final +{ + double penalty; + std::uint8_t source; +}; +struct TurnPenaltySource final +{ + Turn segment; + PenaltySource penalty_source; + // < operator is overloaded here to return a > comparison to be used by the + // std::lower_bound() call in the find() function + bool operator<(const TurnPenaltySource &other) const + { + return std::tie(segment.from, segment.via, segment.to) > + std::tie(other.segment.from, other.segment.via, other.segment.to); + } +}; +using TurnPenaltySourceFlatMap = std::vector; using SegmentSpeedSourceFlatMap = std::vector; // Binary Search over a flattened key,val Segment storage -SegmentSpeedSourceFlatMap::iterator find(SegmentSpeedSourceFlatMap &map, const Segment &key) +template +auto find(const FlatMap &map, const SegmentKey &key) { const auto last = end(map); + auto it = std::lower_bound(begin(map), last, key); - const auto by_segment = [](const SegmentSpeedSource &lhs, const SegmentSpeedSource &rhs) { - return std::tie(lhs.segment.from, lhs.segment.to) > - std::tie(rhs.segment.from, rhs.segment.to); - }; - - auto it = std::lower_bound(begin(map), last, SegmentSpeedSource{key, {0, 0}}, by_segment); - - if (it != last && (std::tie(it->segment.from, it->segment.to) == std::tie(key.from, key.to))) + if (it != last && (it->segment == key.segment)) return it; return last; } -// Convenience aliases. TODO: make actual types at some point in time. -// TODO: turn penalties need flat map + binary search optimization, take a look at segment speeds - -using Turn = std::tuple; -using TurnHasher = std::hash; -using PenaltySource = std::pair; -using TurnPenaltySourceMap = tbb::concurrent_unordered_map; - // Functions for parsing files and creating lookup tables SegmentSpeedSourceFlatMap @@ -344,11 +368,14 @@ parse_segment_lookup_from_csv_files(const std::vector &segment_spee return flatten; } -TurnPenaltySourceMap +TurnPenaltySourceFlatMap parse_turn_penalty_lookup_from_csv_files(const std::vector &turn_penalty_filenames) { + using Mutex = tbb::spin_mutex; + // TODO: shares code with turn penalty lookup parse function - TurnPenaltySourceMap map; + TurnPenaltySourceFlatMap map; + Mutex flatten_mutex; const auto parse_turn_penalty_file = [&](const std::size_t idx) { const auto file_id = idx + 1; // starts at one, zero means we assigned the weight @@ -358,6 +385,8 @@ parse_turn_penalty_lookup_from_csv_files(const std::vector &turn_pe if (!turn_penalty_file) throw util::exception{"Unable to open turn penalty file " + filename}; + TurnPenaltySourceFlatMap local; + std::uint64_t from_node_id{}; std::uint64_t via_node_id{}; std::uint64_t to_node_id{}; @@ -383,9 +412,21 @@ parse_turn_penalty_lookup_from_csv_files(const std::vector &turn_pe if (!ok || it != last) throw util::exception{"Turn penalty file " + filename + " malformed"}; - map[std::make_tuple( - OSMNodeID{from_node_id}, OSMNodeID{via_node_id}, OSMNodeID{to_node_id})] = - std::make_pair(penalty, file_id); + TurnPenaltySource val{ + {OSMNodeID{from_node_id}, OSMNodeID{via_node_id}, OSMNodeID{to_node_id}}, + {penalty, static_cast(file_id)}}; + local.push_back(std::move(val)); + } + + util::SimpleLogger().Write() << "Loaded penalty file " << filename << " with " << local.size() + << " turn penalties"; + + { + Mutex::scoped_lock _{flatten_mutex}; + + map.insert(end(map), + std::make_move_iterator(begin(local)), + std::make_move_iterator(end(local))); } }; @@ -470,7 +511,7 @@ EdgeID Contractor::LoadEdgeExpandedGraph( << " edges from the edge based graph"; SegmentSpeedSourceFlatMap segment_speed_lookup; - TurnPenaltySourceMap turn_penalty_lookup; + TurnPenaltySourceFlatMap turn_penalty_lookup; const auto parse_segment_speeds = [&] { if (update_edge_weights) @@ -616,8 +657,8 @@ EdgeID Contractor::LoadEdgeExpandedGraph( const double segment_length = util::coordinate_calculation::greatCircleDistance( util::Coordinate{u->lon, u->lat}, util::Coordinate{v->lon, v->lat}); - auto forward_speed_iter = - find(segment_speed_lookup, Segment{u->node_id, v->node_id}); + auto forward_speed_iter = find( + segment_speed_lookup, SegmentSpeedSource{{u->node_id, v->node_id}, {0, 0}}); if (forward_speed_iter != segment_speed_lookup.end()) { const auto new_segment_weight = getNewWeight(forward_speed_iter, @@ -644,8 +685,9 @@ EdgeID Contractor::LoadEdgeExpandedGraph( const auto current_rev_weight = m_geometry_rev_weight_list[forward_begin + leaf_object.fwd_segment_position]; - const auto reverse_speed_iter = - find(segment_speed_lookup, Segment{v->node_id, u->node_id}); + const auto reverse_speed_iter = find( + segment_speed_lookup, SegmentSpeedSource{{v->node_id, u->node_id}, {0, 0}}); + if (reverse_speed_iter != segment_speed_lookup.end()) { const auto new_segment_weight = getNewWeight(reverse_speed_iter, @@ -793,7 +835,8 @@ EdgeID Contractor::LoadEdgeExpandedGraph( { auto speed_iter = find(segment_speed_lookup, - Segment{previous_osm_node_id, segmentblocks[i].this_osm_node_id}); + SegmentSpeedSource{ + previous_osm_node_id, segmentblocks[i].this_osm_node_id, {0, 0}}); if (speed_iter != segment_speed_lookup.end()) { if (speed_iter->speed_source.speed > 0) @@ -829,16 +872,18 @@ EdgeID Contractor::LoadEdgeExpandedGraph( continue; } - const auto turn_iter = turn_penalty_lookup.find( - std::make_tuple(penaltyblock->from_id, penaltyblock->via_id, penaltyblock->to_id)); + auto turn_iter = + find(turn_penalty_lookup, + TurnPenaltySource{ + penaltyblock->from_id, penaltyblock->via_id, penaltyblock->to_id, {0, 0}}); if (turn_iter != turn_penalty_lookup.end()) { - int new_turn_weight = static_cast(turn_iter->second.first * 10); + int new_turn_weight = static_cast(turn_iter->penalty_source.penalty * 10); if (new_turn_weight + new_weight < compressed_edge_nodes) { util::SimpleLogger().Write(logWARNING) - << "turn penalty " << turn_iter->second.first << " for turn " + << "turn penalty " << turn_iter->penalty_source.penalty << " for turn " << penaltyblock->from_id << ", " << penaltyblock->via_id << ", " << penaltyblock->to_id << " is too negative: clamping turn weight to " << compressed_edge_nodes;