Switch profiles from Lua to library interface

There's now an abstracted interface and no direct calls to Lua anymore.

fixes #1974
This commit is contained in:
Konstantin Käfer 2016-07-11 17:44:58 +02:00 committed by Patrick Niklaus
parent 9b737230d6
commit 1309dd2a0f
No known key found for this signature in database
GPG Key ID: E426891B5F978B1B
15 changed files with 382 additions and 245 deletions

View File

@ -36,13 +36,13 @@
#include <boost/filesystem/fstream.hpp> #include <boost/filesystem/fstream.hpp>
struct lua_State;
namespace osrm namespace osrm
{ {
namespace extractor namespace extractor
{ {
class ScriptingEnvironment;
namespace lookup namespace lookup
{ {
// Set to 1 byte alignment // Set to 1 byte alignment
@ -95,9 +95,9 @@ class EdgeBasedGraphFactory
const std::vector<std::uint32_t> &turn_lane_offsets, const std::vector<std::uint32_t> &turn_lane_offsets,
const std::vector<guidance::TurnLaneType::Mask> &turn_lane_masks); const std::vector<guidance::TurnLaneType::Mask> &turn_lane_masks);
void Run(const std::string &original_edge_data_filename, void Run(ScriptingEnvironment &scripting_environment,
const std::string &original_edge_data_filename,
const std::string &turn_lane_data_filename, const std::string &turn_lane_data_filename,
lua_State *lua_state,
const std::string &edge_segment_lookup_filename, const std::string &edge_segment_lookup_filename,
const std::string &edge_penalty_filename, const std::string &edge_penalty_filename,
const bool generate_edge_lookup); const bool generate_edge_lookup);
@ -127,8 +127,6 @@ class EdgeBasedGraphFactory
const NodeID w, const NodeID w,
const double angle) const; const double angle) const;
std::int32_t GetTurnPenalty(double angle, lua_State *lua_state) const;
private: private:
using EdgeData = util::NodeBasedDynamicGraph::EdgeData; using EdgeData = util::NodeBasedDynamicGraph::EdgeData;
@ -162,9 +160,9 @@ class EdgeBasedGraphFactory
void CompressGeometry(); void CompressGeometry();
unsigned RenumberEdges(); unsigned RenumberEdges();
void GenerateEdgeExpandedNodes(); void GenerateEdgeExpandedNodes();
void GenerateEdgeExpandedEdges(const std::string &original_edge_data_filename, void GenerateEdgeExpandedEdges(ScriptingEnvironment &scripting_environment,
const std::string &original_edge_data_filename,
const std::string &turn_lane_data_filename, const std::string &turn_lane_data_filename,
lua_State *lua_state,
const std::string &edge_segment_lookup_filename, const std::string &edge_segment_lookup_filename,
const std::string &edge_fixed_penalties_filename, const std::string &edge_fixed_penalties_filename,
const bool generate_edge_lookup); const bool generate_edge_lookup);

View File

@ -34,7 +34,7 @@ class ExtractionContainers
#endif #endif
void PrepareNodes(); void PrepareNodes();
void PrepareRestrictions(); void PrepareRestrictions();
void PrepareEdges(lua_State *segment_state); void PrepareEdges(ScriptingEnvironment &scripting_environment);
void WriteNodes(std::ofstream &file_out_stream) const; void WriteNodes(std::ofstream &file_out_stream) const;
void WriteRestrictions(const std::string &restrictions_file_name) const; void WriteRestrictions(const std::string &restrictions_file_name) const;
@ -69,11 +69,11 @@ class ExtractionContainers
ExtractionContainers(); ExtractionContainers();
void PrepareData(const std::string &output_file_name, void PrepareData(ScriptingEnvironment &scripting_environment,
const std::string &output_file_name,
const std::string &restrictions_file_name, const std::string &restrictions_file_name,
const std::string &names_file_name, const std::string &names_file_name,
const std::string &turn_lane_file_name, const std::string &turn_lane_file_name);
lua_State *segment_state);
}; };
} }
} }

View File

@ -43,20 +43,20 @@ namespace osrm
namespace extractor namespace extractor
{ {
class ScriptingEnvironment;
struct ProfileProperties; struct ProfileProperties;
class Extractor class Extractor
{ {
public: public:
Extractor(ExtractorConfig extractor_config) : config(std::move(extractor_config)) {} Extractor(ExtractorConfig extractor_config) : config(std::move(extractor_config)) {}
int run(); int run(ScriptingEnvironment &scripting_environment);
private: private:
ExtractorConfig config; ExtractorConfig config;
std::pair<std::size_t, EdgeID> std::pair<std::size_t, EdgeID>
BuildEdgeExpandedGraph(lua_State *lua_state, BuildEdgeExpandedGraph(ScriptingEnvironment &scripting_environment,
const ProfileProperties &profile_properties,
std::vector<QueryNode> &internal_to_external_node_map, std::vector<QueryNode> &internal_to_external_node_map,
std::vector<EdgeBasedNode> &node_based_edge_list, std::vector<EdgeBasedNode> &node_based_edge_list,
std::vector<bool> &node_is_startpoint, std::vector<bool> &node_is_startpoint,

View File

@ -77,7 +77,6 @@ struct ExtractorConfig
intersection_class_data_output_path = basepath + ".osrm.icd"; intersection_class_data_output_path = basepath + ".osrm.icd";
} }
boost::filesystem::path config_file_path;
boost::filesystem::path input_path; boost::filesystem::path input_path;
boost::filesystem::path profile_path; boost::filesystem::path profile_path;

View File

@ -8,7 +8,6 @@
#include <string> #include <string>
#include <vector> #include <vector>
struct lua_State;
namespace osmium namespace osmium
{ {
class Relation; class Relation;
@ -19,7 +18,7 @@ namespace osrm
namespace extractor namespace extractor
{ {
struct ProfileProperties; class ScriptingEnvironment;
/** /**
* Parses the relations that represents turn restrictions. * Parses the relations that represents turn restrictions.
@ -42,11 +41,10 @@ struct ProfileProperties;
class RestrictionParser class RestrictionParser
{ {
public: public:
RestrictionParser(lua_State *lua_state, const ProfileProperties &properties); RestrictionParser(ScriptingEnvironment &scripting_environment);
boost::optional<InputRestrictionContainer> TryParse(const osmium::Relation &relation) const; boost::optional<InputRestrictionContainer> TryParse(const osmium::Relation &relation) const;
private: private:
void ReadRestrictionExceptions(lua_State *lua_state);
bool ShouldIgnoreRestriction(const std::string &except_tag_string) const; bool ShouldIgnoreRestriction(const std::string &except_tag_string) const;
std::vector<std::string> restriction_exceptions; std::vector<std::string> restriction_exceptions;

View File

@ -1,52 +1,70 @@
#ifndef SCRIPTING_ENVIRONMENT_HPP #ifndef SCRIPTING_ENVIRONMENT_HPP
#define SCRIPTING_ENVIRONMENT_HPP #define SCRIPTING_ENVIRONMENT_HPP
#include "extractor/guidance/turn_lane_types.hpp"
#include "extractor/internal_extractor_edge.hpp"
#include "extractor/profile_properties.hpp" #include "extractor/profile_properties.hpp"
#include "extractor/raster_source.hpp" #include "extractor/restriction.hpp"
#include "util/lua_util.hpp" #include <osmium/memory/buffer.hpp>
#include <boost/optional/optional.hpp>
#include <tbb/concurrent_vector.h>
#include <memory>
#include <mutex>
#include <string> #include <string>
#include <tbb/enumerable_thread_specific.h> #include <vector>
struct lua_State; namespace osmium
{
class Node;
class Way;
}
namespace osrm namespace osrm
{ {
namespace util
{
struct Coordinate;
}
namespace extractor namespace extractor
{ {
class RestrictionParser;
struct ExtractionNode;
struct ExtractionWay;
/** /**
* Creates a lua context and binds osmium way, node and relation objects and * Abstract class that handles processing osmium ways, nodes and relation objects by applying
* ExtractionWay and ExtractionNode to lua objects. * user supplied profiles.
*
* Each thread has its own lua state which is implemented with thread specific
* storage from TBB.
*/ */
class ScriptingEnvironment class ScriptingEnvironment
{ {
public: public:
struct Context ScriptingEnvironment() = default;
{
ProfileProperties properties;
SourceContainer sources;
util::LuaState state;
};
explicit ScriptingEnvironment(const std::string &file_name);
ScriptingEnvironment(const ScriptingEnvironment &) = delete; ScriptingEnvironment(const ScriptingEnvironment &) = delete;
ScriptingEnvironment &operator=(const ScriptingEnvironment &) = delete; ScriptingEnvironment &operator=(const ScriptingEnvironment &) = delete;
virtual ~ScriptingEnvironment() = default;
Context &GetContex(); virtual const ProfileProperties &GetProfileProperties() = 0;
private: virtual std::vector<std::string> GetNameSuffixList() = 0;
void InitContext(Context &context); virtual std::vector<std::string> GetExceptions() = 0;
std::mutex init_mutex; virtual void SetupSources() = 0;
std::string file_name; virtual int32_t GetTurnPenalty(double angle) = 0;
tbb::enumerable_thread_specific<std::unique_ptr<Context>> script_contexts; virtual void ProcessSegment(const osrm::util::Coordinate &source,
const osrm::util::Coordinate &target,
double distance,
InternalExtractorEdge::WeightData &weight) = 0;
virtual void
ProcessElements(const std::vector<osmium::memory::Buffer::const_iterator> &osm_elements,
const RestrictionParser &restriction_parser,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionNode>> &resulting_nodes,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionWay>> &resulting_ways,
tbb::concurrent_vector<boost::optional<InputRestrictionContainer>>
&resulting_restrictions) = 0;
}; };
} }
} }

View File

@ -0,0 +1,80 @@
#ifndef SCRIPTING_ENVIRONMENT_LUA_HPP
#define SCRIPTING_ENVIRONMENT_LUA_HPP
#include "extractor/scripting_environment.hpp"
#include "extractor/raster_source.hpp"
#include "util/lua_util.hpp"
#include <tbb/enumerable_thread_specific.h>
#include <memory>
#include <mutex>
#include <string>
struct lua_State;
namespace osrm
{
namespace extractor
{
struct LuaScriptingContext final
{
void processNode(const osmium::Node &, ExtractionNode &result);
void processWay(const osmium::Way &, ExtractionWay &result);
ProfileProperties properties;
SourceContainer sources;
util::LuaState state;
bool has_turn_penalty_function;
bool has_node_function;
bool has_way_function;
bool has_segment_function;
};
/**
* Creates a lua context and binds osmium way, node and relation objects and
* ExtractionWay and ExtractionNode to lua objects.
*
* Each thread has its own lua state which is implemented with thread specific
* storage from TBB.
*/
class LuaScriptingEnvironment final : public ScriptingEnvironment
{
public:
explicit LuaScriptingEnvironment(const std::string &file_name);
~LuaScriptingEnvironment() override = default;
const ProfileProperties& GetProfileProperties() override;
LuaScriptingContext &GetLuaContext();
std::vector<std::string> GetNameSuffixList() override;
std::vector<std::string> GetExceptions() override;
void SetupSources() override;
int32_t GetTurnPenalty(double angle) override;
void ProcessSegment(const osrm::util::Coordinate &source,
const osrm::util::Coordinate &target,
double distance,
InternalExtractorEdge::WeightData &weight) override;
void
ProcessElements(const std::vector<osmium::memory::Buffer::const_iterator> &osm_elements,
const RestrictionParser &restriction_parser,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionNode>> &resulting_nodes,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionWay>> &resulting_ways,
tbb::concurrent_vector<boost::optional<InputRestrictionContainer>>
&resulting_restrictions) override;
private:
void InitContext(LuaScriptingContext &context);
std::mutex init_mutex;
std::string file_name;
tbb::enumerable_thread_specific<std::unique_ptr<LuaScriptingContext>> script_contexts;
};
}
}
#endif /* SCRIPTING_ENVIRONMENT_LUA_HPP */

View File

@ -4,19 +4,20 @@
#include <string> #include <string>
#include <unordered_set> #include <unordered_set>
struct lua_State;
namespace osrm namespace osrm
{ {
namespace extractor namespace extractor
{ {
class ScriptingEnvironment;
// A table containing suffixes. // A table containing suffixes.
// At the moment, it is only a front for an unordered set. At some point we might want to make it // At the moment, it is only a front for an unordered set. At some point we might want to make it
// country dependent and have it behave accordingly // country dependent and have it behave accordingly
class SuffixTable final class SuffixTable final
{ {
public: public:
SuffixTable(lua_State *lua_state); SuffixTable(ScriptingEnvironment &scripting_environment);
// check whether a string is part of the know suffix list // check whether a string is part of the know suffix list
bool isSuffix(const std::string &possible_suffix) const; bool isSuffix(const std::string &possible_suffix) const;

View File

@ -4,7 +4,6 @@
#include "util/coordinate_calculation.hpp" #include "util/coordinate_calculation.hpp"
#include "util/exception.hpp" #include "util/exception.hpp"
#include "util/integer_range.hpp" #include "util/integer_range.hpp"
#include "util/lua_util.hpp"
#include "util/percent.hpp" #include "util/percent.hpp"
#include "util/simple_logger.hpp" #include "util/simple_logger.hpp"
#include "util/timing_util.hpp" #include "util/timing_util.hpp"
@ -12,6 +11,7 @@
#include "extractor/guidance/toolkit.hpp" #include "extractor/guidance/toolkit.hpp"
#include "extractor/guidance/turn_analysis.hpp" #include "extractor/guidance/turn_analysis.hpp"
#include "extractor/guidance/turn_lane_handler.hpp" #include "extractor/guidance/turn_lane_handler.hpp"
#include "extractor/scripting_environment.hpp"
#include "extractor/suffix_table.hpp" #include "extractor/suffix_table.hpp"
#include <boost/assert.hpp> #include <boost/assert.hpp>
@ -182,9 +182,9 @@ void EdgeBasedGraphFactory::FlushVectorToStream(
original_edge_data_vector.clear(); original_edge_data_vector.clear();
} }
void EdgeBasedGraphFactory::Run(const std::string &original_edge_data_filename, void EdgeBasedGraphFactory::Run(ScriptingEnvironment &scripting_environment,
const std::string &original_edge_data_filename,
const std::string &turn_lane_data_filename, const std::string &turn_lane_data_filename,
lua_State *lua_state,
const std::string &edge_segment_lookup_filename, const std::string &edge_segment_lookup_filename,
const std::string &edge_penalty_filename, const std::string &edge_penalty_filename,
const bool generate_edge_lookup) const bool generate_edge_lookup)
@ -199,9 +199,9 @@ void EdgeBasedGraphFactory::Run(const std::string &original_edge_data_filename,
TIMER_STOP(generate_nodes); TIMER_STOP(generate_nodes);
TIMER_START(generate_edges); TIMER_START(generate_edges);
GenerateEdgeExpandedEdges(original_edge_data_filename, GenerateEdgeExpandedEdges(scripting_environment,
original_edge_data_filename,
turn_lane_data_filename, turn_lane_data_filename,
lua_state,
edge_segment_lookup_filename, edge_segment_lookup_filename,
edge_penalty_filename, edge_penalty_filename,
generate_edge_lookup); generate_edge_lookup);
@ -298,18 +298,15 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedNodes()
/// Actually it also generates OriginalEdgeData and serializes them... /// Actually it also generates OriginalEdgeData and serializes them...
void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges( void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges(
ScriptingEnvironment &scripting_environment,
const std::string &original_edge_data_filename, const std::string &original_edge_data_filename,
const std::string &turn_lane_data_filename, const std::string &turn_lane_data_filename,
lua_State *lua_state,
const std::string &edge_segment_lookup_filename, const std::string &edge_segment_lookup_filename,
const std::string &edge_fixed_penalties_filename, const std::string &edge_fixed_penalties_filename,
const bool generate_edge_lookup) const bool generate_edge_lookup)
{ {
util::SimpleLogger().Write() << "generating edge-expanded edges"; util::SimpleLogger().Write() << "generating edge-expanded edges";
BOOST_ASSERT(lua_state != nullptr);
const bool use_turn_function = util::luaFunctionExists(lua_state, "turn_function");
std::size_t node_based_edge_counter = 0; std::size_t node_based_edge_counter = 0;
std::size_t original_edges_counter = 0; std::size_t original_edges_counter = 0;
restricted_turns_counter = 0; restricted_turns_counter = 0;
@ -338,7 +335,7 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges(
// Three nested loop look super-linear, but we are dealing with a (kind of) // Three nested loop look super-linear, but we are dealing with a (kind of)
// linear number of turns only. // linear number of turns only.
util::Percent progress(m_node_based_graph->GetNumberOfNodes()); util::Percent progress(m_node_based_graph->GetNumberOfNodes());
SuffixTable street_name_suffix_table(lua_state); SuffixTable street_name_suffix_table(scripting_environment);
guidance::TurnAnalysis turn_analysis(*m_node_based_graph, guidance::TurnAnalysis turn_analysis(*m_node_based_graph,
m_node_info_list, m_node_info_list,
*m_restriction_map, *m_restriction_map,
@ -410,8 +407,6 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges(
for (const auto turn : possible_turns) for (const auto turn : possible_turns)
{ {
const double turn_angle = turn.angle;
// only add an edge if turn is not prohibited // only add an edge if turn is not prohibited
const EdgeData &edge_data1 = m_node_based_graph->GetEdgeData(edge_from_u); const EdgeData &edge_data1 = m_node_based_graph->GetEdgeData(edge_from_u);
const EdgeData &edge_data2 = m_node_based_graph->GetEdgeData(turn.eid); const EdgeData &edge_data2 = m_node_based_graph->GetEdgeData(turn.eid);
@ -427,8 +422,7 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges(
distance += profile_properties.traffic_signal_penalty; distance += profile_properties.traffic_signal_penalty;
} }
const int turn_penalty = const int32_t turn_penalty = scripting_environment.GetTurnPenalty(180. - turn.angle);
use_turn_function ? GetTurnPenalty(turn_angle, lua_state) : 0;
const auto turn_instruction = turn.instruction; const auto turn_instruction = turn.instruction;
if (guidance::isUturn(turn_instruction)) if (guidance::isUturn(turn_instruction))
@ -615,23 +609,5 @@ std::vector<util::guidance::EntryClass> EdgeBasedGraphFactory::GetEntryClasses()
return result; return result;
} }
int EdgeBasedGraphFactory::GetTurnPenalty(double angle, lua_State *lua_state) const
{
BOOST_ASSERT(lua_state != nullptr);
try
{
// call lua profile to compute turn penalty
double penalty = luabind::call_function<double>(lua_state, "turn_function", 180. - angle);
BOOST_ASSERT(penalty < std::numeric_limits<int>::max());
BOOST_ASSERT(penalty > std::numeric_limits<int>::min());
return boost::numeric_cast<int>(penalty);
}
catch (const luabind::error &er)
{
util::SimpleLogger().Write(logWARNING) << er.what();
}
return 0;
}
} // namespace extractor } // namespace extractor
} // namespace osrm } // namespace osrm

View File

@ -7,7 +7,6 @@
#include "util/exception.hpp" #include "util/exception.hpp"
#include "util/fingerprint.hpp" #include "util/fingerprint.hpp"
#include "util/io.hpp" #include "util/io.hpp"
#include "util/lua_util.hpp"
#include "util/simple_logger.hpp" #include "util/simple_logger.hpp"
#include "util/timing_util.hpp" #include "util/timing_util.hpp"
@ -17,8 +16,6 @@
#include <boost/numeric/conversion/cast.hpp> #include <boost/numeric/conversion/cast.hpp>
#include <boost/ref.hpp> #include <boost/ref.hpp>
#include <luabind/luabind.hpp>
#include <stxxl/sort> #include <stxxl/sort>
#include <chrono> #include <chrono>
@ -137,11 +134,11 @@ ExtractionContainers::ExtractionContainers()
* - merge edges with nodes to include location of start/end points and serialize * - merge edges with nodes to include location of start/end points and serialize
* *
*/ */
void ExtractionContainers::PrepareData(const std::string &output_file_name, void ExtractionContainers::PrepareData(ScriptingEnvironment &scripting_environment,
const std::string &output_file_name,
const std::string &restrictions_file_name, const std::string &restrictions_file_name,
const std::string &name_file_name, const std::string &name_file_name,
const std::string &turn_lane_file_name, const std::string &turn_lane_file_name)
lua_State *segment_state)
{ {
try try
{ {
@ -152,7 +149,7 @@ void ExtractionContainers::PrepareData(const std::string &output_file_name,
PrepareNodes(); PrepareNodes();
WriteNodes(file_out_stream); WriteNodes(file_out_stream);
PrepareEdges(segment_state); PrepareEdges(scripting_environment);
WriteEdges(file_out_stream); WriteEdges(file_out_stream);
PrepareRestrictions(); PrepareRestrictions();
@ -304,7 +301,7 @@ void ExtractionContainers::PrepareNodes()
std::cout << "ok, after " << TIMER_SEC(id_map) << "s" << std::endl; std::cout << "ok, after " << TIMER_SEC(id_map) << "s" << std::endl;
} }
void ExtractionContainers::PrepareEdges(lua_State *segment_state) void ExtractionContainers::PrepareEdges(ScriptingEnvironment &scripting_environment)
{ {
// Sort edges by start. // Sort edges by start.
std::cout << "[extractor] Sorting edges by start ... " << std::flush; std::cout << "[extractor] Sorting edges by start ... " << std::flush;
@ -386,8 +383,6 @@ void ExtractionContainers::PrepareEdges(lua_State *segment_state)
const auto all_edges_list_end_ = all_edges_list.end(); const auto all_edges_list_end_ = all_edges_list.end();
const auto all_nodes_list_end_ = all_nodes_list.end(); const auto all_nodes_list_end_ = all_nodes_list.end();
const auto has_segment_function = util::luaFunctionExists(segment_state, "segment_function");
while (edge_iterator != all_edges_list_end_ && node_iterator != all_nodes_list_end_) while (edge_iterator != all_edges_list_end_ && node_iterator != all_nodes_list_end_)
{ {
// skip all invalid edges // skip all invalid edges
@ -423,15 +418,8 @@ void ExtractionContainers::PrepareEdges(lua_State *segment_state)
edge_iterator->source_coordinate, edge_iterator->source_coordinate,
util::Coordinate(node_iterator->lon, node_iterator->lat)); util::Coordinate(node_iterator->lon, node_iterator->lat));
if (has_segment_function) scripting_environment.ProcessSegment(
{ edge_iterator->source_coordinate, *node_iterator, distance, edge_iterator->weight_data);
luabind::call_function<void>(segment_state,
"segment_function",
boost::cref(edge_iterator->source_coordinate),
boost::cref(*node_iterator),
distance,
boost::ref(edge_iterator->weight_data));
}
const double weight = [distance](const InternalExtractorEdge::WeightData &data) { const double weight = [distance](const InternalExtractorEdge::WeightData &data) {
switch (data.type) switch (data.type)

View File

@ -11,7 +11,6 @@
#include "extractor/raster_source.hpp" #include "extractor/raster_source.hpp"
#include "util/graph_loader.hpp" #include "util/graph_loader.hpp"
#include "util/io.hpp" #include "util/io.hpp"
#include "util/lua_util.hpp"
#include "util/make_unique.hpp" #include "util/make_unique.hpp"
#include "util/name_table.hpp" #include "util/name_table.hpp"
#include "util/range_table.hpp" #include "util/range_table.hpp"
@ -29,11 +28,9 @@
#include <boost/filesystem/fstream.hpp> #include <boost/filesystem/fstream.hpp>
#include <boost/optional/optional.hpp> #include <boost/optional/optional.hpp>
#include <luabind/luabind.hpp>
#include <osmium/io/any_input.hpp> #include <osmium/io/any_input.hpp>
#include <tbb/parallel_for.h> #include <tbb/concurrent_vector.h>
#include <tbb/task_scheduler_init.h> #include <tbb/task_scheduler_init.h>
#include <cstdlib> #include <cstdlib>
@ -74,11 +71,8 @@ namespace extractor
* graph * graph
* *
*/ */
int Extractor::run() int Extractor::run(ScriptingEnvironment &scripting_environment)
{ {
// setup scripting environment
ScriptingEnvironment scripting_environment(config.profile_path.string().c_str());
try try
{ {
util::LogPolicy::GetInstance().Unmute(); util::LogPolicy::GetInstance().Unmute();
@ -90,7 +84,9 @@ int Extractor::run()
tbb::task_scheduler_init init(number_of_threads); tbb::task_scheduler_init init(number_of_threads);
util::SimpleLogger().Write() << "Input file: " << config.input_path.filename().string(); util::SimpleLogger().Write() << "Input file: " << config.input_path.filename().string();
if (!config.profile_path.empty()) {
util::SimpleLogger().Write() << "Profile: " << config.profile_path.filename().string(); util::SimpleLogger().Write() << "Profile: " << config.profile_path.filename().string();
}
util::SimpleLogger().Write() << "Threads: " << number_of_threads; util::SimpleLogger().Write() << "Threads: " << number_of_threads;
ExtractionContainers extraction_containers; ExtractionContainers extraction_containers;
@ -100,21 +96,15 @@ int Extractor::run()
osmium::io::Reader reader(input_file); osmium::io::Reader reader(input_file);
const osmium::io::Header header = reader.header(); const osmium::io::Header header = reader.header();
std::atomic<unsigned> number_of_nodes{0}; unsigned number_of_nodes = 0;
std::atomic<unsigned> number_of_ways{0}; unsigned number_of_ways = 0;
std::atomic<unsigned> number_of_relations{0}; unsigned number_of_relations = 0;
std::atomic<unsigned> number_of_others{0};
util::SimpleLogger().Write() << "Parsing in progress.."; util::SimpleLogger().Write() << "Parsing in progress..";
TIMER_START(parsing); TIMER_START(parsing);
auto &main_context = scripting_environment.GetContex();
// setup raster sources // setup raster sources
if (util::luaFunctionExists(main_context.state, "source_function")) scripting_environment.SetupSources();
{
luabind::call_function<void>(main_context.state, "source_function");
}
std::string generator = header.get("generator"); std::string generator = header.get("generator");
if (generator.empty()) if (generator.empty())
@ -140,7 +130,7 @@ int Extractor::run()
tbb::concurrent_vector<boost::optional<InputRestrictionContainer>> resulting_restrictions; tbb::concurrent_vector<boost::optional<InputRestrictionContainer>> resulting_restrictions;
// setup restriction parser // setup restriction parser
const RestrictionParser restriction_parser(main_context.state, main_context.properties); const RestrictionParser restriction_parser(scripting_environment);
while (const osmium::memory::Buffer buffer = reader.read()) while (const osmium::memory::Buffer buffer = reader.read())
{ {
@ -156,52 +146,14 @@ int Extractor::run()
resulting_ways.clear(); resulting_ways.clear();
resulting_restrictions.clear(); resulting_restrictions.clear();
// parse OSM entities in parallel, store in resulting vectors scripting_environment.ProcessElements(osm_elements,
tbb::parallel_for( restriction_parser,
tbb::blocked_range<std::size_t>(0, osm_elements.size()), resulting_nodes,
[&](const tbb::blocked_range<std::size_t> &range) { resulting_ways,
ExtractionNode result_node; resulting_restrictions);
ExtractionWay result_way;
auto &local_context = scripting_environment.GetContex();
for (auto x = range.begin(), end = range.end(); x != end; ++x)
{
const auto entity = osm_elements[x];
switch (entity->type())
{
case osmium::item_type::node:
result_node.clear();
++number_of_nodes;
luabind::call_function<void>(
local_context.state,
"node_function",
boost::cref(static_cast<const osmium::Node &>(*entity)),
boost::ref(result_node));
resulting_nodes.push_back(std::make_pair(x, std::move(result_node)));
break;
case osmium::item_type::way:
result_way.clear();
++number_of_ways;
luabind::call_function<void>(
local_context.state,
"way_function",
boost::cref(static_cast<const osmium::Way &>(*entity)),
boost::ref(result_way));
resulting_ways.push_back(std::make_pair(x, std::move(result_way)));
break;
case osmium::item_type::relation:
++number_of_relations;
resulting_restrictions.push_back(restriction_parser.TryParse(
static_cast<const osmium::Relation &>(*entity)));
break;
default:
++number_of_others;
break;
}
}
});
number_of_nodes += resulting_nodes.size();
// put parsed objects thru extractor callbacks // put parsed objects thru extractor callbacks
for (const auto &result : resulting_nodes) for (const auto &result : resulting_nodes)
{ {
@ -209,11 +161,13 @@ int Extractor::run()
static_cast<const osmium::Node &>(*(osm_elements[result.first])), static_cast<const osmium::Node &>(*(osm_elements[result.first])),
result.second); result.second);
} }
number_of_ways += resulting_ways.size();
for (const auto &result : resulting_ways) for (const auto &result : resulting_ways)
{ {
extractor_callbacks->ProcessWay( extractor_callbacks->ProcessWay(
static_cast<const osmium::Way &>(*(osm_elements[result.first])), result.second); static_cast<const osmium::Way &>(*(osm_elements[result.first])), result.second);
} }
number_of_relations += resulting_restrictions.size();
for (const auto &result : resulting_restrictions) for (const auto &result : resulting_restrictions)
{ {
extractor_callbacks->ProcessRestriction(result); extractor_callbacks->ProcessRestriction(result);
@ -223,10 +177,9 @@ int Extractor::run()
util::SimpleLogger().Write() << "Parsing finished after " << TIMER_SEC(parsing) util::SimpleLogger().Write() << "Parsing finished after " << TIMER_SEC(parsing)
<< " seconds"; << " seconds";
util::SimpleLogger().Write() << "Raw input contains " << number_of_nodes.load() util::SimpleLogger().Write() << "Raw input contains " << number_of_nodes << " nodes, "
<< " nodes, " << number_of_ways.load() << " ways, and " << number_of_ways << " ways, and " << number_of_relations
<< number_of_relations.load() << " relations, and " << " relations";
<< number_of_others.load() << " unknown entities";
extractor_callbacks.reset(); extractor_callbacks.reset();
@ -236,13 +189,14 @@ int Extractor::run()
return 1; return 1;
} }
extraction_containers.PrepareData(config.output_file_name, extraction_containers.PrepareData(scripting_environment,
config.output_file_name,
config.restriction_file_name, config.restriction_file_name,
config.names_file_name, config.names_file_name,
config.turn_lane_descriptions_file_name, config.turn_lane_descriptions_file_name);
main_context.state);
WriteProfileProperties(config.profile_properties_output_path, main_context.properties); WriteProfileProperties(config.profile_properties_output_path,
scripting_environment.GetProfileProperties());
TIMER_STOP(extracting); TIMER_STOP(extracting);
util::SimpleLogger().Write() << "extraction finished after " << TIMER_SEC(extracting) util::SimpleLogger().Write() << "extraction finished after " << TIMER_SEC(extracting)
@ -261,9 +215,6 @@ int Extractor::run()
// that is better for routing. Every edge becomes a node, and every valid // that is better for routing. Every edge becomes a node, and every valid
// movement (e.g. turn from A->B, and B->A) becomes an edge // movement (e.g. turn from A->B, and B->A) becomes an edge
// //
auto &main_context = scripting_environment.GetContex();
util::SimpleLogger().Write() << "Generating edge-expanded graph representation"; util::SimpleLogger().Write() << "Generating edge-expanded graph representation";
TIMER_START(expansion); TIMER_START(expansion);
@ -273,8 +224,7 @@ int Extractor::run()
std::vector<bool> node_is_startpoint; std::vector<bool> node_is_startpoint;
std::vector<EdgeWeight> edge_based_node_weights; std::vector<EdgeWeight> edge_based_node_weights;
std::vector<QueryNode> internal_to_external_node_map; std::vector<QueryNode> internal_to_external_node_map;
auto graph_size = BuildEdgeExpandedGraph(main_context.state, auto graph_size = BuildEdgeExpandedGraph(scripting_environment,
main_context.properties,
internal_to_external_node_map, internal_to_external_node_map,
edge_based_node_list, edge_based_node_list,
node_is_startpoint, node_is_startpoint,
@ -477,8 +427,7 @@ Extractor::LoadNodeBasedGraph(std::unordered_set<NodeID> &barrier_nodes,
\brief Building an edge-expanded graph from node-based input and turn restrictions \brief Building an edge-expanded graph from node-based input and turn restrictions
*/ */
std::pair<std::size_t, EdgeID> std::pair<std::size_t, EdgeID>
Extractor::BuildEdgeExpandedGraph(lua_State *lua_state, Extractor::BuildEdgeExpandedGraph(ScriptingEnvironment &scripting_environment,
const ProfileProperties &profile_properties,
std::vector<QueryNode> &internal_to_external_node_map, std::vector<QueryNode> &internal_to_external_node_map,
std::vector<EdgeBasedNode> &node_based_edge_list, std::vector<EdgeBasedNode> &node_based_edge_list,
std::vector<bool> &node_is_startpoint, std::vector<bool> &node_is_startpoint,
@ -520,14 +469,14 @@ Extractor::BuildEdgeExpandedGraph(lua_State *lua_state,
traffic_lights, traffic_lights,
std::const_pointer_cast<RestrictionMap const>(restriction_map), std::const_pointer_cast<RestrictionMap const>(restriction_map),
internal_to_external_node_map, internal_to_external_node_map,
profile_properties, scripting_environment.GetProfileProperties(),
name_table, name_table,
turn_lane_offsets, turn_lane_offsets,
turn_lane_masks); turn_lane_masks);
edge_based_graph_factory.Run(config.edge_output_path, edge_based_graph_factory.Run(scripting_environment,
config.edge_output_path,
config.turn_lane_data_file_name, config.turn_lane_data_file_name,
lua_state,
config.edge_segment_lookup_path, config.edge_segment_lookup_path,
config.edge_penalty_path, config.edge_penalty_path,
config.generate_edge_lookup); config.generate_edge_lookup);

View File

@ -1,9 +1,9 @@
#include "extractor/restriction_parser.hpp" #include "extractor/restriction_parser.hpp"
#include "extractor/profile_properties.hpp" #include "extractor/profile_properties.hpp"
#include "extractor/scripting_environment.hpp"
#include "extractor/external_memory_node.hpp" #include "extractor/external_memory_node.hpp"
#include "util/exception.hpp"
#include "util/lua_util.hpp"
#include "util/simple_logger.hpp" #include "util/simple_logger.hpp"
#include <boost/algorithm/string.hpp> #include <boost/algorithm/string.hpp>
@ -24,33 +24,15 @@ namespace osrm
namespace extractor namespace extractor
{ {
namespace RestrictionParser::RestrictionParser(ScriptingEnvironment &scripting_environment)
{ : use_turn_restrictions(scripting_environment.GetProfileProperties().use_turn_restrictions)
int luaErrorCallback(lua_State *lua_state)
{
std::string error_msg = lua_tostring(lua_state, -1);
throw util::exception("ERROR occurred in profile script:\n" + error_msg);
}
}
RestrictionParser::RestrictionParser(lua_State *lua_state, const ProfileProperties &properties)
: use_turn_restrictions(properties.use_turn_restrictions)
{ {
if (use_turn_restrictions) if (use_turn_restrictions)
{ {
ReadRestrictionExceptions(lua_state); restriction_exceptions = scripting_environment.GetExceptions();
}
}
void RestrictionParser::ReadRestrictionExceptions(lua_State *lua_state)
{
if (util::luaFunctionExists(lua_state, "get_exceptions"))
{
luabind::set_pcall_callback(&luaErrorCallback);
// get list of turn restriction exceptions
luabind::call_function<void>(
lua_state, "get_exceptions", boost::ref(restriction_exceptions));
const unsigned exception_count = restriction_exceptions.size(); const unsigned exception_count = restriction_exceptions.size();
if (exception_count)
{
util::SimpleLogger().Write() << "Found " << exception_count util::SimpleLogger().Write() << "Found " << exception_count
<< " exceptions to turn restrictions:"; << " exceptions to turn restrictions:";
for (const std::string &str : restriction_exceptions) for (const std::string &str : restriction_exceptions)
@ -62,6 +44,7 @@ void RestrictionParser::ReadRestrictionExceptions(lua_State *lua_state)
{ {
util::SimpleLogger().Write() << "Found no exceptions to turn restrictions"; util::SimpleLogger().Write() << "Found no exceptions to turn restrictions";
} }
}
} }
/** /**

View File

@ -1,4 +1,4 @@
#include "extractor/scripting_environment.hpp" #include "extractor/scripting_environment_lua.hpp"
#include "extractor/external_memory_node.hpp" #include "extractor/external_memory_node.hpp"
#include "extractor/extraction_helper_functions.hpp" #include "extractor/extraction_helper_functions.hpp"
@ -7,6 +7,7 @@
#include "extractor/internal_extractor_edge.hpp" #include "extractor/internal_extractor_edge.hpp"
#include "extractor/profile_properties.hpp" #include "extractor/profile_properties.hpp"
#include "extractor/raster_source.hpp" #include "extractor/raster_source.hpp"
#include "extractor/restriction_parser.hpp"
#include "util/exception.hpp" #include "util/exception.hpp"
#include "util/lua_util.hpp" #include "util/lua_util.hpp"
#include "util/make_unique.hpp" #include "util/make_unique.hpp"
@ -19,6 +20,8 @@
#include <osmium/osm.hpp> #include <osmium/osm.hpp>
#include <tbb/parallel_for.h>
#include <sstream> #include <sstream>
namespace osrm namespace osrm
@ -58,12 +61,13 @@ int luaErrorCallback(lua_State *state)
} }
} }
ScriptingEnvironment::ScriptingEnvironment(const std::string &file_name) : file_name(file_name) LuaScriptingEnvironment::LuaScriptingEnvironment(const std::string &file_name)
: file_name(file_name)
{ {
util::SimpleLogger().Write() << "Using script " << file_name; util::SimpleLogger().Write() << "Using script " << file_name;
} }
void ScriptingEnvironment::InitContext(ScriptingEnvironment::Context &context) void LuaScriptingEnvironment::InitContext(LuaScriptingContext &context)
{ {
typedef double (osmium::Location::*location_member_ptr_type)() const; typedef double (osmium::Location::*location_member_ptr_type)() const;
@ -187,21 +191,180 @@ void ScriptingEnvironment::InitContext(ScriptingEnvironment::Context &context)
error_stream << error_msg; error_stream << error_msg;
throw util::exception("ERROR occurred in profile script:\n" + error_stream.str()); throw util::exception("ERROR occurred in profile script:\n" + error_stream.str());
} }
context.has_turn_penalty_function = util::luaFunctionExists(context.state, "turn_function");
context.has_node_function = util::luaFunctionExists(context.state, "node_function");
context.has_way_function = util::luaFunctionExists(context.state, "way_function");
context.has_segment_function = util::luaFunctionExists(context.state, "segment_function");
} }
ScriptingEnvironment::Context &ScriptingEnvironment::GetContex() const ProfileProperties &LuaScriptingEnvironment::GetProfileProperties()
{
return GetLuaContext().properties;
}
LuaScriptingContext &LuaScriptingEnvironment::GetLuaContext()
{ {
std::lock_guard<std::mutex> lock(init_mutex); std::lock_guard<std::mutex> lock(init_mutex);
bool initialized = false; bool initialized = false;
auto &ref = script_contexts.local(initialized); auto &ref = script_contexts.local(initialized);
if (!initialized) if (!initialized)
{ {
ref = util::make_unique<Context>(); ref = util::make_unique<LuaScriptingContext>();
InitContext(*ref); InitContext(*ref);
} }
luabind::set_pcall_callback(&luaErrorCallback); luabind::set_pcall_callback(&luaErrorCallback);
return *ref; return *ref;
} }
void LuaScriptingEnvironment::ProcessElements(
const std::vector<osmium::memory::Buffer::const_iterator> &osm_elements,
const RestrictionParser &restriction_parser,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionNode>> &resulting_nodes,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionWay>> &resulting_ways,
tbb::concurrent_vector<boost::optional<InputRestrictionContainer>> &resulting_restrictions)
{
// parse OSM entities in parallel, store in resulting vectors
tbb::parallel_for(
tbb::blocked_range<std::size_t>(0, osm_elements.size()),
[&](const tbb::blocked_range<std::size_t> &range) {
ExtractionNode result_node;
ExtractionWay result_way;
auto &local_context = this->GetLuaContext();
for (auto x = range.begin(), end = range.end(); x != end; ++x)
{
const auto entity = osm_elements[x];
switch (entity->type())
{
case osmium::item_type::node:
result_node.clear();
if (local_context.has_node_function)
{
local_context.processNode(static_cast<const osmium::Node &>(*entity),
result_node);
}
resulting_nodes.push_back(std::make_pair(x, std::move(result_node)));
break;
case osmium::item_type::way:
result_way.clear();
if (local_context.has_way_function)
{
local_context.processWay(static_cast<const osmium::Way &>(*entity),
result_way);
}
resulting_ways.push_back(std::make_pair(x, std::move(result_way)));
break;
case osmium::item_type::relation:
resulting_restrictions.push_back(restriction_parser.TryParse(
static_cast<const osmium::Relation &>(*entity)));
break;
default:
break;
}
}
});
}
std::vector<std::string> LuaScriptingEnvironment::GetNameSuffixList()
{
auto &context = GetLuaContext();
BOOST_ASSERT(context.state != nullptr);
if (!util::luaFunctionExists(context.state, "get_name_suffix_list"))
return {};
std::vector<std::string> suffixes_vector;
try
{
// call lua profile to compute turn penalty
luabind::call_function<void>(
context.state, "get_name_suffix_list", boost::ref(suffixes_vector));
}
catch (const luabind::error &er)
{
util::SimpleLogger().Write(logWARNING) << er.what();
}
return suffixes_vector;
}
std::vector<std::string> LuaScriptingEnvironment::GetExceptions()
{
auto &context = GetLuaContext();
BOOST_ASSERT(context.state != nullptr);
std::vector<std::string> restriction_exceptions;
if (util::luaFunctionExists(context.state, "get_exceptions"))
{
// get list of turn restriction exceptions
luabind::call_function<void>(
context.state, "get_exceptions", boost::ref(restriction_exceptions));
}
return restriction_exceptions;
}
void LuaScriptingEnvironment::SetupSources()
{
auto &context = GetLuaContext();
BOOST_ASSERT(context.state != nullptr);
if (util::luaFunctionExists(context.state, "source_function"))
{
luabind::call_function<void>(context.state, "source_function");
}
}
int32_t LuaScriptingEnvironment::GetTurnPenalty(const double angle)
{
auto &context = GetLuaContext();
if (context.has_turn_penalty_function)
{
BOOST_ASSERT(context.state != nullptr);
try
{
// call lua profile to compute turn penalty
const double penalty =
luabind::call_function<double>(context.state, "turn_function", angle);
BOOST_ASSERT(penalty < std::numeric_limits<int32_t>::max());
BOOST_ASSERT(penalty > std::numeric_limits<int32_t>::min());
return boost::numeric_cast<int32_t>(penalty);
}
catch (const luabind::error &er)
{
util::SimpleLogger().Write(logWARNING) << er.what();
}
}
return 0;
}
void LuaScriptingEnvironment::ProcessSegment(const osrm::util::Coordinate &source,
const osrm::util::Coordinate &target,
double distance,
InternalExtractorEdge::WeightData &weight)
{
auto &context = GetLuaContext();
if (context.has_segment_function)
{
BOOST_ASSERT(context.state != nullptr);
luabind::call_function<void>(context.state,
"segment_function",
boost::cref(source),
boost::cref(target),
distance,
boost::ref(weight));
}
}
void LuaScriptingContext::processNode(const osmium::Node &node, ExtractionNode &result)
{
BOOST_ASSERT(state != nullptr);
luabind::call_function<void>(state, "node_function", boost::cref(node), boost::ref(result));
}
void LuaScriptingContext::processWay(const osmium::Way &way, ExtractionWay &result)
{
BOOST_ASSERT(state != nullptr);
luabind::call_function<void>(state, "way_function", boost::cref(way), boost::ref(result));
}
} }
} }

View File

@ -1,38 +1,17 @@
#include "extractor/suffix_table.hpp" #include "extractor/suffix_table.hpp"
#include "util/lua_util.hpp" #include "extractor/scripting_environment.hpp"
#include "util/simple_logger.hpp"
#include <boost/algorithm/string.hpp> #include <boost/algorithm/string.hpp>
#include <boost/assert.hpp>
#include <boost/ref.hpp>
#include <iterator>
#include <vector>
namespace osrm namespace osrm
{ {
namespace extractor namespace extractor
{ {
SuffixTable::SuffixTable(lua_State *lua_state) SuffixTable::SuffixTable(ScriptingEnvironment &scripting_environment)
{ {
BOOST_ASSERT(lua_state != nullptr); std::vector<std::string> suffixes_vector = scripting_environment.GetNameSuffixList();
if (!util::luaFunctionExists(lua_state, "get_name_suffix_list"))
return;
std::vector<std::string> suffixes_vector;
try
{
// call lua profile to compute turn penalty
luabind::call_function<void>(
lua_state, "get_name_suffix_list", boost::ref(suffixes_vector));
}
catch (const luabind::error &er)
{
util::SimpleLogger().Write(logWARNING) << er.what();
}
for (auto &suffix : suffixes_vector) for (auto &suffix : suffixes_vector)
boost::algorithm::to_lower(suffix); boost::algorithm::to_lower(suffix);
suffix_set.insert(std::begin(suffixes_vector), std::end(suffixes_vector)); suffix_set.insert(std::begin(suffixes_vector), std::end(suffixes_vector));

View File

@ -1,5 +1,6 @@
#include "extractor/extractor.hpp" #include "extractor/extractor.hpp"
#include "extractor/extractor_config.hpp" #include "extractor/extractor_config.hpp"
#include "extractor/scripting_environment_lua.hpp"
#include "util/simple_logger.hpp" #include "util/simple_logger.hpp"
#include "util/version.hpp" #include "util/version.hpp"
@ -147,7 +148,11 @@ int main(int argc, char *argv[]) try
<< "Profile " << extractor_config.profile_path.string() << " not found!"; << "Profile " << extractor_config.profile_path.string() << " not found!";
return EXIT_FAILURE; return EXIT_FAILURE;
} }
return extractor::Extractor(extractor_config).run();
// setup scripting environment
extractor::LuaScriptingEnvironment scripting_environment(
extractor_config.profile_path.string().c_str());
return extractor::Extractor(extractor_config).run(scripting_environment);
} }
catch (const std::bad_alloc &e) catch (const std::bad_alloc &e)
{ {