From 1309dd2a0f5b58360b0fd82cbc9dca36740d5852 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Konstantin=20K=C3=A4fer?= Date: Mon, 11 Jul 2016 17:44:58 +0200 Subject: [PATCH] Switch profiles from Lua to library interface There's now an abstracted interface and no direct calls to Lua anymore. fixes #1974 --- .../extractor/edge_based_graph_factory.hpp | 14 +- include/extractor/extraction_containers.hpp | 8 +- include/extractor/extractor.hpp | 6 +- include/extractor/extractor_config.hpp | 1 - include/extractor/restriction_parser.hpp | 6 +- include/extractor/scripting_environment.hpp | 70 ++++--- .../extractor/scripting_environment_lua.hpp | 80 ++++++++ include/extractor/suffix_table.hpp | 7 +- src/extractor/edge_based_graph_factory.cpp | 40 +--- src/extractor/extraction_containers.cpp | 26 +-- src/extractor/extractor.cpp | 113 ++++-------- src/extractor/restriction_parser.cpp | 49 ++--- ...ment.cpp => scripting_environment_lua.cpp} | 173 +++++++++++++++++- src/extractor/suffix_table.cpp | 27 +-- src/tools/extract.cpp | 7 +- 15 files changed, 382 insertions(+), 245 deletions(-) create mode 100644 include/extractor/scripting_environment_lua.hpp rename src/extractor/{scripting_environment.cpp => scripting_environment_lua.cpp} (57%) diff --git a/include/extractor/edge_based_graph_factory.hpp b/include/extractor/edge_based_graph_factory.hpp index f8d1327ec..eb9ec351a 100644 --- a/include/extractor/edge_based_graph_factory.hpp +++ b/include/extractor/edge_based_graph_factory.hpp @@ -36,13 +36,13 @@ #include -struct lua_State; - namespace osrm { namespace extractor { +class ScriptingEnvironment; + namespace lookup { // Set to 1 byte alignment @@ -95,9 +95,9 @@ class EdgeBasedGraphFactory const std::vector &turn_lane_offsets, const std::vector &turn_lane_masks); - void Run(const std::string &original_edge_data_filename, + void Run(ScriptingEnvironment &scripting_environment, + const std::string &original_edge_data_filename, const std::string &turn_lane_data_filename, - lua_State *lua_state, const std::string &edge_segment_lookup_filename, const std::string &edge_penalty_filename, const bool generate_edge_lookup); @@ -127,8 +127,6 @@ class EdgeBasedGraphFactory const NodeID w, const double angle) const; - std::int32_t GetTurnPenalty(double angle, lua_State *lua_state) const; - private: using EdgeData = util::NodeBasedDynamicGraph::EdgeData; @@ -162,9 +160,9 @@ class EdgeBasedGraphFactory void CompressGeometry(); unsigned RenumberEdges(); void GenerateEdgeExpandedNodes(); - void GenerateEdgeExpandedEdges(const std::string &original_edge_data_filename, + void GenerateEdgeExpandedEdges(ScriptingEnvironment &scripting_environment, + const std::string &original_edge_data_filename, const std::string &turn_lane_data_filename, - lua_State *lua_state, const std::string &edge_segment_lookup_filename, const std::string &edge_fixed_penalties_filename, const bool generate_edge_lookup); diff --git a/include/extractor/extraction_containers.hpp b/include/extractor/extraction_containers.hpp index ccd710453..77d6fd2cb 100644 --- a/include/extractor/extraction_containers.hpp +++ b/include/extractor/extraction_containers.hpp @@ -34,7 +34,7 @@ class ExtractionContainers #endif void PrepareNodes(); void PrepareRestrictions(); - void PrepareEdges(lua_State *segment_state); + void PrepareEdges(ScriptingEnvironment &scripting_environment); void WriteNodes(std::ofstream &file_out_stream) const; void WriteRestrictions(const std::string &restrictions_file_name) const; @@ -69,11 +69,11 @@ class ExtractionContainers ExtractionContainers(); - void PrepareData(const std::string &output_file_name, + void PrepareData(ScriptingEnvironment &scripting_environment, + const std::string &output_file_name, const std::string &restrictions_file_name, const std::string &names_file_name, - const std::string &turn_lane_file_name, - lua_State *segment_state); + const std::string &turn_lane_file_name); }; } } diff --git a/include/extractor/extractor.hpp b/include/extractor/extractor.hpp index fa355c814..0aa3af545 100644 --- a/include/extractor/extractor.hpp +++ b/include/extractor/extractor.hpp @@ -43,20 +43,20 @@ namespace osrm namespace extractor { +class ScriptingEnvironment; struct ProfileProperties; class Extractor { public: Extractor(ExtractorConfig extractor_config) : config(std::move(extractor_config)) {} - int run(); + int run(ScriptingEnvironment &scripting_environment); private: ExtractorConfig config; std::pair - BuildEdgeExpandedGraph(lua_State *lua_state, - const ProfileProperties &profile_properties, + BuildEdgeExpandedGraph(ScriptingEnvironment &scripting_environment, std::vector &internal_to_external_node_map, std::vector &node_based_edge_list, std::vector &node_is_startpoint, diff --git a/include/extractor/extractor_config.hpp b/include/extractor/extractor_config.hpp index d0b5fd2c6..5e15e76fa 100644 --- a/include/extractor/extractor_config.hpp +++ b/include/extractor/extractor_config.hpp @@ -77,7 +77,6 @@ struct ExtractorConfig intersection_class_data_output_path = basepath + ".osrm.icd"; } - boost::filesystem::path config_file_path; boost::filesystem::path input_path; boost::filesystem::path profile_path; diff --git a/include/extractor/restriction_parser.hpp b/include/extractor/restriction_parser.hpp index 77926c991..c5eede2f8 100644 --- a/include/extractor/restriction_parser.hpp +++ b/include/extractor/restriction_parser.hpp @@ -8,7 +8,6 @@ #include #include -struct lua_State; namespace osmium { class Relation; @@ -19,7 +18,7 @@ namespace osrm namespace extractor { -struct ProfileProperties; +class ScriptingEnvironment; /** * Parses the relations that represents turn restrictions. @@ -42,11 +41,10 @@ struct ProfileProperties; class RestrictionParser { public: - RestrictionParser(lua_State *lua_state, const ProfileProperties &properties); + RestrictionParser(ScriptingEnvironment &scripting_environment); boost::optional TryParse(const osmium::Relation &relation) const; private: - void ReadRestrictionExceptions(lua_State *lua_state); bool ShouldIgnoreRestriction(const std::string &except_tag_string) const; std::vector restriction_exceptions; diff --git a/include/extractor/scripting_environment.hpp b/include/extractor/scripting_environment.hpp index 16c8da185..a450b3154 100644 --- a/include/extractor/scripting_environment.hpp +++ b/include/extractor/scripting_environment.hpp @@ -1,52 +1,70 @@ #ifndef SCRIPTING_ENVIRONMENT_HPP #define SCRIPTING_ENVIRONMENT_HPP +#include "extractor/guidance/turn_lane_types.hpp" +#include "extractor/internal_extractor_edge.hpp" #include "extractor/profile_properties.hpp" -#include "extractor/raster_source.hpp" +#include "extractor/restriction.hpp" -#include "util/lua_util.hpp" +#include + +#include + +#include -#include -#include #include -#include +#include -struct lua_State; +namespace osmium +{ +class Node; +class Way; +} namespace osrm { + +namespace util +{ +struct Coordinate; +} + namespace extractor { +class RestrictionParser; +struct ExtractionNode; +struct ExtractionWay; + /** - * Creates a lua context and binds osmium way, node and relation objects and - * ExtractionWay and ExtractionNode to lua objects. - * - * Each thread has its own lua state which is implemented with thread specific - * storage from TBB. + * Abstract class that handles processing osmium ways, nodes and relation objects by applying + * user supplied profiles. */ class ScriptingEnvironment { public: - struct Context - { - ProfileProperties properties; - SourceContainer sources; - util::LuaState state; - }; - - explicit ScriptingEnvironment(const std::string &file_name); - + ScriptingEnvironment() = default; ScriptingEnvironment(const ScriptingEnvironment &) = delete; ScriptingEnvironment &operator=(const ScriptingEnvironment &) = delete; + virtual ~ScriptingEnvironment() = default; - Context &GetContex(); + virtual const ProfileProperties &GetProfileProperties() = 0; - private: - void InitContext(Context &context); - std::mutex init_mutex; - std::string file_name; - tbb::enumerable_thread_specific> script_contexts; + virtual std::vector GetNameSuffixList() = 0; + virtual std::vector GetExceptions() = 0; + virtual void SetupSources() = 0; + virtual int32_t GetTurnPenalty(double angle) = 0; + virtual void ProcessSegment(const osrm::util::Coordinate &source, + const osrm::util::Coordinate &target, + double distance, + InternalExtractorEdge::WeightData &weight) = 0; + virtual void + ProcessElements(const std::vector &osm_elements, + const RestrictionParser &restriction_parser, + tbb::concurrent_vector> &resulting_nodes, + tbb::concurrent_vector> &resulting_ways, + tbb::concurrent_vector> + &resulting_restrictions) = 0; }; } } diff --git a/include/extractor/scripting_environment_lua.hpp b/include/extractor/scripting_environment_lua.hpp new file mode 100644 index 000000000..908a4d5d2 --- /dev/null +++ b/include/extractor/scripting_environment_lua.hpp @@ -0,0 +1,80 @@ +#ifndef SCRIPTING_ENVIRONMENT_LUA_HPP +#define SCRIPTING_ENVIRONMENT_LUA_HPP + +#include "extractor/scripting_environment.hpp" + +#include "extractor/raster_source.hpp" + +#include "util/lua_util.hpp" + +#include + +#include +#include +#include + +struct lua_State; + +namespace osrm +{ +namespace extractor +{ + +struct LuaScriptingContext final +{ + void processNode(const osmium::Node &, ExtractionNode &result); + void processWay(const osmium::Way &, ExtractionWay &result); + + ProfileProperties properties; + SourceContainer sources; + util::LuaState state; + + bool has_turn_penalty_function; + bool has_node_function; + bool has_way_function; + bool has_segment_function; +}; + +/** + * Creates a lua context and binds osmium way, node and relation objects and + * ExtractionWay and ExtractionNode to lua objects. + * + * Each thread has its own lua state which is implemented with thread specific + * storage from TBB. + */ +class LuaScriptingEnvironment final : public ScriptingEnvironment +{ + public: + explicit LuaScriptingEnvironment(const std::string &file_name); + ~LuaScriptingEnvironment() override = default; + + const ProfileProperties& GetProfileProperties() override; + + LuaScriptingContext &GetLuaContext(); + + std::vector GetNameSuffixList() override; + std::vector GetExceptions() override; + void SetupSources() override; + int32_t GetTurnPenalty(double angle) override; + void ProcessSegment(const osrm::util::Coordinate &source, + const osrm::util::Coordinate &target, + double distance, + InternalExtractorEdge::WeightData &weight) override; + void + ProcessElements(const std::vector &osm_elements, + const RestrictionParser &restriction_parser, + tbb::concurrent_vector> &resulting_nodes, + tbb::concurrent_vector> &resulting_ways, + tbb::concurrent_vector> + &resulting_restrictions) override; + + private: + void InitContext(LuaScriptingContext &context); + std::mutex init_mutex; + std::string file_name; + tbb::enumerable_thread_specific> script_contexts; +}; +} +} + +#endif /* SCRIPTING_ENVIRONMENT_LUA_HPP */ diff --git a/include/extractor/suffix_table.hpp b/include/extractor/suffix_table.hpp index af5f01649..31830b90e 100644 --- a/include/extractor/suffix_table.hpp +++ b/include/extractor/suffix_table.hpp @@ -4,19 +4,20 @@ #include #include -struct lua_State; - namespace osrm { namespace extractor { + +class ScriptingEnvironment; + // A table containing suffixes. // At the moment, it is only a front for an unordered set. At some point we might want to make it // country dependent and have it behave accordingly class SuffixTable final { public: - SuffixTable(lua_State *lua_state); + SuffixTable(ScriptingEnvironment &scripting_environment); // check whether a string is part of the know suffix list bool isSuffix(const std::string &possible_suffix) const; diff --git a/src/extractor/edge_based_graph_factory.cpp b/src/extractor/edge_based_graph_factory.cpp index d7182f0e1..1690bbd70 100644 --- a/src/extractor/edge_based_graph_factory.cpp +++ b/src/extractor/edge_based_graph_factory.cpp @@ -4,7 +4,6 @@ #include "util/coordinate_calculation.hpp" #include "util/exception.hpp" #include "util/integer_range.hpp" -#include "util/lua_util.hpp" #include "util/percent.hpp" #include "util/simple_logger.hpp" #include "util/timing_util.hpp" @@ -12,6 +11,7 @@ #include "extractor/guidance/toolkit.hpp" #include "extractor/guidance/turn_analysis.hpp" #include "extractor/guidance/turn_lane_handler.hpp" +#include "extractor/scripting_environment.hpp" #include "extractor/suffix_table.hpp" #include @@ -182,9 +182,9 @@ void EdgeBasedGraphFactory::FlushVectorToStream( original_edge_data_vector.clear(); } -void EdgeBasedGraphFactory::Run(const std::string &original_edge_data_filename, +void EdgeBasedGraphFactory::Run(ScriptingEnvironment &scripting_environment, + const std::string &original_edge_data_filename, const std::string &turn_lane_data_filename, - lua_State *lua_state, const std::string &edge_segment_lookup_filename, const std::string &edge_penalty_filename, const bool generate_edge_lookup) @@ -199,9 +199,9 @@ void EdgeBasedGraphFactory::Run(const std::string &original_edge_data_filename, TIMER_STOP(generate_nodes); TIMER_START(generate_edges); - GenerateEdgeExpandedEdges(original_edge_data_filename, + GenerateEdgeExpandedEdges(scripting_environment, + original_edge_data_filename, turn_lane_data_filename, - lua_state, edge_segment_lookup_filename, edge_penalty_filename, generate_edge_lookup); @@ -298,18 +298,15 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedNodes() /// Actually it also generates OriginalEdgeData and serializes them... void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges( + ScriptingEnvironment &scripting_environment, const std::string &original_edge_data_filename, const std::string &turn_lane_data_filename, - lua_State *lua_state, const std::string &edge_segment_lookup_filename, const std::string &edge_fixed_penalties_filename, const bool generate_edge_lookup) { util::SimpleLogger().Write() << "generating edge-expanded edges"; - BOOST_ASSERT(lua_state != nullptr); - const bool use_turn_function = util::luaFunctionExists(lua_state, "turn_function"); - std::size_t node_based_edge_counter = 0; std::size_t original_edges_counter = 0; restricted_turns_counter = 0; @@ -338,7 +335,7 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges( // Three nested loop look super-linear, but we are dealing with a (kind of) // linear number of turns only. util::Percent progress(m_node_based_graph->GetNumberOfNodes()); - SuffixTable street_name_suffix_table(lua_state); + SuffixTable street_name_suffix_table(scripting_environment); guidance::TurnAnalysis turn_analysis(*m_node_based_graph, m_node_info_list, *m_restriction_map, @@ -410,8 +407,6 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges( for (const auto turn : possible_turns) { - const double turn_angle = turn.angle; - // only add an edge if turn is not prohibited const EdgeData &edge_data1 = m_node_based_graph->GetEdgeData(edge_from_u); const EdgeData &edge_data2 = m_node_based_graph->GetEdgeData(turn.eid); @@ -427,8 +422,7 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges( distance += profile_properties.traffic_signal_penalty; } - const int turn_penalty = - use_turn_function ? GetTurnPenalty(turn_angle, lua_state) : 0; + const int32_t turn_penalty = scripting_environment.GetTurnPenalty(180. - turn.angle); const auto turn_instruction = turn.instruction; if (guidance::isUturn(turn_instruction)) @@ -615,23 +609,5 @@ std::vector EdgeBasedGraphFactory::GetEntryClasses() return result; } -int EdgeBasedGraphFactory::GetTurnPenalty(double angle, lua_State *lua_state) const -{ - BOOST_ASSERT(lua_state != nullptr); - try - { - // call lua profile to compute turn penalty - double penalty = luabind::call_function(lua_state, "turn_function", 180. - angle); - BOOST_ASSERT(penalty < std::numeric_limits::max()); - BOOST_ASSERT(penalty > std::numeric_limits::min()); - return boost::numeric_cast(penalty); - } - catch (const luabind::error &er) - { - util::SimpleLogger().Write(logWARNING) << er.what(); - } - return 0; -} - } // namespace extractor } // namespace osrm diff --git a/src/extractor/extraction_containers.cpp b/src/extractor/extraction_containers.cpp index 908083ff3..ba5e81e59 100644 --- a/src/extractor/extraction_containers.cpp +++ b/src/extractor/extraction_containers.cpp @@ -7,7 +7,6 @@ #include "util/exception.hpp" #include "util/fingerprint.hpp" #include "util/io.hpp" -#include "util/lua_util.hpp" #include "util/simple_logger.hpp" #include "util/timing_util.hpp" @@ -17,8 +16,6 @@ #include #include -#include - #include #include @@ -137,11 +134,11 @@ ExtractionContainers::ExtractionContainers() * - merge edges with nodes to include location of start/end points and serialize * */ -void ExtractionContainers::PrepareData(const std::string &output_file_name, +void ExtractionContainers::PrepareData(ScriptingEnvironment &scripting_environment, + const std::string &output_file_name, const std::string &restrictions_file_name, const std::string &name_file_name, - const std::string &turn_lane_file_name, - lua_State *segment_state) + const std::string &turn_lane_file_name) { try { @@ -152,7 +149,7 @@ void ExtractionContainers::PrepareData(const std::string &output_file_name, PrepareNodes(); WriteNodes(file_out_stream); - PrepareEdges(segment_state); + PrepareEdges(scripting_environment); WriteEdges(file_out_stream); PrepareRestrictions(); @@ -304,7 +301,7 @@ void ExtractionContainers::PrepareNodes() std::cout << "ok, after " << TIMER_SEC(id_map) << "s" << std::endl; } -void ExtractionContainers::PrepareEdges(lua_State *segment_state) +void ExtractionContainers::PrepareEdges(ScriptingEnvironment &scripting_environment) { // Sort edges by start. std::cout << "[extractor] Sorting edges by start ... " << std::flush; @@ -386,8 +383,6 @@ void ExtractionContainers::PrepareEdges(lua_State *segment_state) const auto all_edges_list_end_ = all_edges_list.end(); const auto all_nodes_list_end_ = all_nodes_list.end(); - const auto has_segment_function = util::luaFunctionExists(segment_state, "segment_function"); - while (edge_iterator != all_edges_list_end_ && node_iterator != all_nodes_list_end_) { // skip all invalid edges @@ -423,15 +418,8 @@ void ExtractionContainers::PrepareEdges(lua_State *segment_state) edge_iterator->source_coordinate, util::Coordinate(node_iterator->lon, node_iterator->lat)); - if (has_segment_function) - { - luabind::call_function(segment_state, - "segment_function", - boost::cref(edge_iterator->source_coordinate), - boost::cref(*node_iterator), - distance, - boost::ref(edge_iterator->weight_data)); - } + scripting_environment.ProcessSegment( + edge_iterator->source_coordinate, *node_iterator, distance, edge_iterator->weight_data); const double weight = [distance](const InternalExtractorEdge::WeightData &data) { switch (data.type) diff --git a/src/extractor/extractor.cpp b/src/extractor/extractor.cpp index 45ed06e91..b2aa74c41 100644 --- a/src/extractor/extractor.cpp +++ b/src/extractor/extractor.cpp @@ -11,7 +11,6 @@ #include "extractor/raster_source.hpp" #include "util/graph_loader.hpp" #include "util/io.hpp" -#include "util/lua_util.hpp" #include "util/make_unique.hpp" #include "util/name_table.hpp" #include "util/range_table.hpp" @@ -29,11 +28,9 @@ #include #include -#include - #include -#include +#include #include #include @@ -74,11 +71,8 @@ namespace extractor * graph * */ -int Extractor::run() +int Extractor::run(ScriptingEnvironment &scripting_environment) { - // setup scripting environment - ScriptingEnvironment scripting_environment(config.profile_path.string().c_str()); - try { util::LogPolicy::GetInstance().Unmute(); @@ -90,7 +84,9 @@ int Extractor::run() tbb::task_scheduler_init init(number_of_threads); util::SimpleLogger().Write() << "Input file: " << config.input_path.filename().string(); - util::SimpleLogger().Write() << "Profile: " << config.profile_path.filename().string(); + if (!config.profile_path.empty()) { + util::SimpleLogger().Write() << "Profile: " << config.profile_path.filename().string(); + } util::SimpleLogger().Write() << "Threads: " << number_of_threads; ExtractionContainers extraction_containers; @@ -100,21 +96,15 @@ int Extractor::run() osmium::io::Reader reader(input_file); const osmium::io::Header header = reader.header(); - std::atomic number_of_nodes{0}; - std::atomic number_of_ways{0}; - std::atomic number_of_relations{0}; - std::atomic number_of_others{0}; + unsigned number_of_nodes = 0; + unsigned number_of_ways = 0; + unsigned number_of_relations = 0; util::SimpleLogger().Write() << "Parsing in progress.."; TIMER_START(parsing); - auto &main_context = scripting_environment.GetContex(); - // setup raster sources - if (util::luaFunctionExists(main_context.state, "source_function")) - { - luabind::call_function(main_context.state, "source_function"); - } + scripting_environment.SetupSources(); std::string generator = header.get("generator"); if (generator.empty()) @@ -140,7 +130,7 @@ int Extractor::run() tbb::concurrent_vector> resulting_restrictions; // setup restriction parser - const RestrictionParser restriction_parser(main_context.state, main_context.properties); + const RestrictionParser restriction_parser(scripting_environment); while (const osmium::memory::Buffer buffer = reader.read()) { @@ -156,52 +146,14 @@ int Extractor::run() resulting_ways.clear(); resulting_restrictions.clear(); - // parse OSM entities in parallel, store in resulting vectors - tbb::parallel_for( - tbb::blocked_range(0, osm_elements.size()), - [&](const tbb::blocked_range &range) { - ExtractionNode result_node; - ExtractionWay result_way; - auto &local_context = scripting_environment.GetContex(); + scripting_environment.ProcessElements(osm_elements, + restriction_parser, + resulting_nodes, + resulting_ways, + resulting_restrictions); - for (auto x = range.begin(), end = range.end(); x != end; ++x) - { - const auto entity = osm_elements[x]; - - switch (entity->type()) - { - case osmium::item_type::node: - result_node.clear(); - ++number_of_nodes; - luabind::call_function( - local_context.state, - "node_function", - boost::cref(static_cast(*entity)), - boost::ref(result_node)); - resulting_nodes.push_back(std::make_pair(x, std::move(result_node))); - break; - case osmium::item_type::way: - result_way.clear(); - ++number_of_ways; - luabind::call_function( - local_context.state, - "way_function", - boost::cref(static_cast(*entity)), - boost::ref(result_way)); - resulting_ways.push_back(std::make_pair(x, std::move(result_way))); - break; - case osmium::item_type::relation: - ++number_of_relations; - resulting_restrictions.push_back(restriction_parser.TryParse( - static_cast(*entity))); - break; - default: - ++number_of_others; - break; - } - } - }); + number_of_nodes += resulting_nodes.size(); // put parsed objects thru extractor callbacks for (const auto &result : resulting_nodes) { @@ -209,11 +161,13 @@ int Extractor::run() static_cast(*(osm_elements[result.first])), result.second); } + number_of_ways += resulting_ways.size(); for (const auto &result : resulting_ways) { extractor_callbacks->ProcessWay( static_cast(*(osm_elements[result.first])), result.second); } + number_of_relations += resulting_restrictions.size(); for (const auto &result : resulting_restrictions) { extractor_callbacks->ProcessRestriction(result); @@ -223,10 +177,9 @@ int Extractor::run() util::SimpleLogger().Write() << "Parsing finished after " << TIMER_SEC(parsing) << " seconds"; - util::SimpleLogger().Write() << "Raw input contains " << number_of_nodes.load() - << " nodes, " << number_of_ways.load() << " ways, and " - << number_of_relations.load() << " relations, and " - << number_of_others.load() << " unknown entities"; + util::SimpleLogger().Write() << "Raw input contains " << number_of_nodes << " nodes, " + << number_of_ways << " ways, and " << number_of_relations + << " relations"; extractor_callbacks.reset(); @@ -236,13 +189,14 @@ int Extractor::run() return 1; } - extraction_containers.PrepareData(config.output_file_name, + extraction_containers.PrepareData(scripting_environment, + config.output_file_name, config.restriction_file_name, config.names_file_name, - config.turn_lane_descriptions_file_name, - main_context.state); + config.turn_lane_descriptions_file_name); - WriteProfileProperties(config.profile_properties_output_path, main_context.properties); + WriteProfileProperties(config.profile_properties_output_path, + scripting_environment.GetProfileProperties()); TIMER_STOP(extracting); util::SimpleLogger().Write() << "extraction finished after " << TIMER_SEC(extracting) @@ -261,9 +215,6 @@ int Extractor::run() // that is better for routing. Every edge becomes a node, and every valid // movement (e.g. turn from A->B, and B->A) becomes an edge // - - auto &main_context = scripting_environment.GetContex(); - util::SimpleLogger().Write() << "Generating edge-expanded graph representation"; TIMER_START(expansion); @@ -273,8 +224,7 @@ int Extractor::run() std::vector node_is_startpoint; std::vector edge_based_node_weights; std::vector internal_to_external_node_map; - auto graph_size = BuildEdgeExpandedGraph(main_context.state, - main_context.properties, + auto graph_size = BuildEdgeExpandedGraph(scripting_environment, internal_to_external_node_map, edge_based_node_list, node_is_startpoint, @@ -477,8 +427,7 @@ Extractor::LoadNodeBasedGraph(std::unordered_set &barrier_nodes, \brief Building an edge-expanded graph from node-based input and turn restrictions */ std::pair -Extractor::BuildEdgeExpandedGraph(lua_State *lua_state, - const ProfileProperties &profile_properties, +Extractor::BuildEdgeExpandedGraph(ScriptingEnvironment &scripting_environment, std::vector &internal_to_external_node_map, std::vector &node_based_edge_list, std::vector &node_is_startpoint, @@ -520,14 +469,14 @@ Extractor::BuildEdgeExpandedGraph(lua_State *lua_state, traffic_lights, std::const_pointer_cast(restriction_map), internal_to_external_node_map, - profile_properties, + scripting_environment.GetProfileProperties(), name_table, turn_lane_offsets, turn_lane_masks); - edge_based_graph_factory.Run(config.edge_output_path, + edge_based_graph_factory.Run(scripting_environment, + config.edge_output_path, config.turn_lane_data_file_name, - lua_state, config.edge_segment_lookup_path, config.edge_penalty_path, config.generate_edge_lookup); diff --git a/src/extractor/restriction_parser.cpp b/src/extractor/restriction_parser.cpp index a6cc0e7c6..f7840974d 100644 --- a/src/extractor/restriction_parser.cpp +++ b/src/extractor/restriction_parser.cpp @@ -1,9 +1,9 @@ #include "extractor/restriction_parser.hpp" #include "extractor/profile_properties.hpp" +#include "extractor/scripting_environment.hpp" #include "extractor/external_memory_node.hpp" -#include "util/exception.hpp" -#include "util/lua_util.hpp" + #include "util/simple_logger.hpp" #include @@ -24,43 +24,26 @@ namespace osrm namespace extractor { -namespace -{ -int luaErrorCallback(lua_State *lua_state) -{ - std::string error_msg = lua_tostring(lua_state, -1); - throw util::exception("ERROR occurred in profile script:\n" + error_msg); -} -} - -RestrictionParser::RestrictionParser(lua_State *lua_state, const ProfileProperties &properties) - : use_turn_restrictions(properties.use_turn_restrictions) +RestrictionParser::RestrictionParser(ScriptingEnvironment &scripting_environment) + : use_turn_restrictions(scripting_environment.GetProfileProperties().use_turn_restrictions) { if (use_turn_restrictions) { - ReadRestrictionExceptions(lua_state); - } -} - -void RestrictionParser::ReadRestrictionExceptions(lua_State *lua_state) -{ - if (util::luaFunctionExists(lua_state, "get_exceptions")) - { - luabind::set_pcall_callback(&luaErrorCallback); - // get list of turn restriction exceptions - luabind::call_function( - lua_state, "get_exceptions", boost::ref(restriction_exceptions)); + restriction_exceptions = scripting_environment.GetExceptions(); const unsigned exception_count = restriction_exceptions.size(); - util::SimpleLogger().Write() << "Found " << exception_count - << " exceptions to turn restrictions:"; - for (const std::string &str : restriction_exceptions) + if (exception_count) { - util::SimpleLogger().Write() << " " << str; + util::SimpleLogger().Write() << "Found " << exception_count + << " exceptions to turn restrictions:"; + for (const std::string &str : restriction_exceptions) + { + util::SimpleLogger().Write() << " " << str; + } + } + else + { + util::SimpleLogger().Write() << "Found no exceptions to turn restrictions"; } - } - else - { - util::SimpleLogger().Write() << "Found no exceptions to turn restrictions"; } } diff --git a/src/extractor/scripting_environment.cpp b/src/extractor/scripting_environment_lua.cpp similarity index 57% rename from src/extractor/scripting_environment.cpp rename to src/extractor/scripting_environment_lua.cpp index 899537e44..eaba199c7 100644 --- a/src/extractor/scripting_environment.cpp +++ b/src/extractor/scripting_environment_lua.cpp @@ -1,4 +1,4 @@ -#include "extractor/scripting_environment.hpp" +#include "extractor/scripting_environment_lua.hpp" #include "extractor/external_memory_node.hpp" #include "extractor/extraction_helper_functions.hpp" @@ -7,6 +7,7 @@ #include "extractor/internal_extractor_edge.hpp" #include "extractor/profile_properties.hpp" #include "extractor/raster_source.hpp" +#include "extractor/restriction_parser.hpp" #include "util/exception.hpp" #include "util/lua_util.hpp" #include "util/make_unique.hpp" @@ -19,6 +20,8 @@ #include +#include + #include namespace osrm @@ -58,12 +61,13 @@ int luaErrorCallback(lua_State *state) } } -ScriptingEnvironment::ScriptingEnvironment(const std::string &file_name) : file_name(file_name) +LuaScriptingEnvironment::LuaScriptingEnvironment(const std::string &file_name) + : file_name(file_name) { util::SimpleLogger().Write() << "Using script " << file_name; } -void ScriptingEnvironment::InitContext(ScriptingEnvironment::Context &context) +void LuaScriptingEnvironment::InitContext(LuaScriptingContext &context) { typedef double (osmium::Location::*location_member_ptr_type)() const; @@ -187,21 +191,180 @@ void ScriptingEnvironment::InitContext(ScriptingEnvironment::Context &context) error_stream << error_msg; throw util::exception("ERROR occurred in profile script:\n" + error_stream.str()); } + + context.has_turn_penalty_function = util::luaFunctionExists(context.state, "turn_function"); + context.has_node_function = util::luaFunctionExists(context.state, "node_function"); + context.has_way_function = util::luaFunctionExists(context.state, "way_function"); + context.has_segment_function = util::luaFunctionExists(context.state, "segment_function"); } -ScriptingEnvironment::Context &ScriptingEnvironment::GetContex() +const ProfileProperties &LuaScriptingEnvironment::GetProfileProperties() +{ + return GetLuaContext().properties; +} + +LuaScriptingContext &LuaScriptingEnvironment::GetLuaContext() { std::lock_guard lock(init_mutex); bool initialized = false; auto &ref = script_contexts.local(initialized); if (!initialized) { - ref = util::make_unique(); + ref = util::make_unique(); InitContext(*ref); } luabind::set_pcall_callback(&luaErrorCallback); return *ref; } + +void LuaScriptingEnvironment::ProcessElements( + const std::vector &osm_elements, + const RestrictionParser &restriction_parser, + tbb::concurrent_vector> &resulting_nodes, + tbb::concurrent_vector> &resulting_ways, + tbb::concurrent_vector> &resulting_restrictions) +{ + // parse OSM entities in parallel, store in resulting vectors + tbb::parallel_for( + tbb::blocked_range(0, osm_elements.size()), + [&](const tbb::blocked_range &range) { + ExtractionNode result_node; + ExtractionWay result_way; + auto &local_context = this->GetLuaContext(); + + for (auto x = range.begin(), end = range.end(); x != end; ++x) + { + const auto entity = osm_elements[x]; + + switch (entity->type()) + { + case osmium::item_type::node: + result_node.clear(); + if (local_context.has_node_function) + { + local_context.processNode(static_cast(*entity), + result_node); + } + resulting_nodes.push_back(std::make_pair(x, std::move(result_node))); + break; + case osmium::item_type::way: + result_way.clear(); + if (local_context.has_way_function) + { + local_context.processWay(static_cast(*entity), + result_way); + } + resulting_ways.push_back(std::make_pair(x, std::move(result_way))); + break; + case osmium::item_type::relation: + resulting_restrictions.push_back(restriction_parser.TryParse( + static_cast(*entity))); + break; + default: + break; + } + } + }); +} + +std::vector LuaScriptingEnvironment::GetNameSuffixList() +{ + auto &context = GetLuaContext(); + BOOST_ASSERT(context.state != nullptr); + if (!util::luaFunctionExists(context.state, "get_name_suffix_list")) + return {}; + + std::vector suffixes_vector; + try + { + // call lua profile to compute turn penalty + luabind::call_function( + context.state, "get_name_suffix_list", boost::ref(suffixes_vector)); + } + catch (const luabind::error &er) + { + util::SimpleLogger().Write(logWARNING) << er.what(); + } + + return suffixes_vector; +} + +std::vector LuaScriptingEnvironment::GetExceptions() +{ + auto &context = GetLuaContext(); + BOOST_ASSERT(context.state != nullptr); + std::vector restriction_exceptions; + if (util::luaFunctionExists(context.state, "get_exceptions")) + { + // get list of turn restriction exceptions + luabind::call_function( + context.state, "get_exceptions", boost::ref(restriction_exceptions)); + } + return restriction_exceptions; +} + +void LuaScriptingEnvironment::SetupSources() +{ + auto &context = GetLuaContext(); + BOOST_ASSERT(context.state != nullptr); + if (util::luaFunctionExists(context.state, "source_function")) + { + luabind::call_function(context.state, "source_function"); + } +} + +int32_t LuaScriptingEnvironment::GetTurnPenalty(const double angle) +{ + auto &context = GetLuaContext(); + if (context.has_turn_penalty_function) + { + BOOST_ASSERT(context.state != nullptr); + try + { + // call lua profile to compute turn penalty + const double penalty = + luabind::call_function(context.state, "turn_function", angle); + BOOST_ASSERT(penalty < std::numeric_limits::max()); + BOOST_ASSERT(penalty > std::numeric_limits::min()); + return boost::numeric_cast(penalty); + } + catch (const luabind::error &er) + { + util::SimpleLogger().Write(logWARNING) << er.what(); + } + } + return 0; +} + +void LuaScriptingEnvironment::ProcessSegment(const osrm::util::Coordinate &source, + const osrm::util::Coordinate &target, + double distance, + InternalExtractorEdge::WeightData &weight) +{ + auto &context = GetLuaContext(); + if (context.has_segment_function) + { + BOOST_ASSERT(context.state != nullptr); + luabind::call_function(context.state, + "segment_function", + boost::cref(source), + boost::cref(target), + distance, + boost::ref(weight)); + } +} + +void LuaScriptingContext::processNode(const osmium::Node &node, ExtractionNode &result) +{ + BOOST_ASSERT(state != nullptr); + luabind::call_function(state, "node_function", boost::cref(node), boost::ref(result)); +} + +void LuaScriptingContext::processWay(const osmium::Way &way, ExtractionWay &result) +{ + BOOST_ASSERT(state != nullptr); + luabind::call_function(state, "way_function", boost::cref(way), boost::ref(result)); +} } } diff --git a/src/extractor/suffix_table.cpp b/src/extractor/suffix_table.cpp index 10d54bbbe..f160fcd93 100644 --- a/src/extractor/suffix_table.cpp +++ b/src/extractor/suffix_table.cpp @@ -1,38 +1,17 @@ #include "extractor/suffix_table.hpp" -#include "util/lua_util.hpp" -#include "util/simple_logger.hpp" +#include "extractor/scripting_environment.hpp" #include -#include -#include - -#include -#include namespace osrm { namespace extractor { -SuffixTable::SuffixTable(lua_State *lua_state) +SuffixTable::SuffixTable(ScriptingEnvironment &scripting_environment) { - BOOST_ASSERT(lua_state != nullptr); - if (!util::luaFunctionExists(lua_state, "get_name_suffix_list")) - return; - - std::vector suffixes_vector; - try - { - // call lua profile to compute turn penalty - luabind::call_function( - lua_state, "get_name_suffix_list", boost::ref(suffixes_vector)); - } - catch (const luabind::error &er) - { - util::SimpleLogger().Write(logWARNING) << er.what(); - } - + std::vector suffixes_vector = scripting_environment.GetNameSuffixList(); for (auto &suffix : suffixes_vector) boost::algorithm::to_lower(suffix); suffix_set.insert(std::begin(suffixes_vector), std::end(suffixes_vector)); diff --git a/src/tools/extract.cpp b/src/tools/extract.cpp index cc5915c3b..a41942780 100644 --- a/src/tools/extract.cpp +++ b/src/tools/extract.cpp @@ -1,5 +1,6 @@ #include "extractor/extractor.hpp" #include "extractor/extractor_config.hpp" +#include "extractor/scripting_environment_lua.hpp" #include "util/simple_logger.hpp" #include "util/version.hpp" @@ -147,7 +148,11 @@ int main(int argc, char *argv[]) try << "Profile " << extractor_config.profile_path.string() << " not found!"; return EXIT_FAILURE; } - return extractor::Extractor(extractor_config).run(); + + // setup scripting environment + extractor::LuaScriptingEnvironment scripting_environment( + extractor_config.profile_path.string().c_str()); + return extractor::Extractor(extractor_config).run(scripting_environment); } catch (const std::bad_alloc &e) {