Switch profiles from Lua to library interface

There's now an abstracted interface and no direct calls to Lua anymore.

fixes #1974
This commit is contained in:
Konstantin Käfer 2016-07-11 17:44:58 +02:00 committed by Patrick Niklaus
parent 9b737230d6
commit 1309dd2a0f
No known key found for this signature in database
GPG Key ID: E426891B5F978B1B
15 changed files with 382 additions and 245 deletions

View File

@ -36,13 +36,13 @@
#include <boost/filesystem/fstream.hpp>
struct lua_State;
namespace osrm
{
namespace extractor
{
class ScriptingEnvironment;
namespace lookup
{
// Set to 1 byte alignment
@ -95,9 +95,9 @@ class EdgeBasedGraphFactory
const std::vector<std::uint32_t> &turn_lane_offsets,
const std::vector<guidance::TurnLaneType::Mask> &turn_lane_masks);
void Run(const std::string &original_edge_data_filename,
void Run(ScriptingEnvironment &scripting_environment,
const std::string &original_edge_data_filename,
const std::string &turn_lane_data_filename,
lua_State *lua_state,
const std::string &edge_segment_lookup_filename,
const std::string &edge_penalty_filename,
const bool generate_edge_lookup);
@ -127,8 +127,6 @@ class EdgeBasedGraphFactory
const NodeID w,
const double angle) const;
std::int32_t GetTurnPenalty(double angle, lua_State *lua_state) const;
private:
using EdgeData = util::NodeBasedDynamicGraph::EdgeData;
@ -162,9 +160,9 @@ class EdgeBasedGraphFactory
void CompressGeometry();
unsigned RenumberEdges();
void GenerateEdgeExpandedNodes();
void GenerateEdgeExpandedEdges(const std::string &original_edge_data_filename,
void GenerateEdgeExpandedEdges(ScriptingEnvironment &scripting_environment,
const std::string &original_edge_data_filename,
const std::string &turn_lane_data_filename,
lua_State *lua_state,
const std::string &edge_segment_lookup_filename,
const std::string &edge_fixed_penalties_filename,
const bool generate_edge_lookup);

View File

@ -34,7 +34,7 @@ class ExtractionContainers
#endif
void PrepareNodes();
void PrepareRestrictions();
void PrepareEdges(lua_State *segment_state);
void PrepareEdges(ScriptingEnvironment &scripting_environment);
void WriteNodes(std::ofstream &file_out_stream) const;
void WriteRestrictions(const std::string &restrictions_file_name) const;
@ -69,11 +69,11 @@ class ExtractionContainers
ExtractionContainers();
void PrepareData(const std::string &output_file_name,
void PrepareData(ScriptingEnvironment &scripting_environment,
const std::string &output_file_name,
const std::string &restrictions_file_name,
const std::string &names_file_name,
const std::string &turn_lane_file_name,
lua_State *segment_state);
const std::string &turn_lane_file_name);
};
}
}

View File

@ -43,20 +43,20 @@ namespace osrm
namespace extractor
{
class ScriptingEnvironment;
struct ProfileProperties;
class Extractor
{
public:
Extractor(ExtractorConfig extractor_config) : config(std::move(extractor_config)) {}
int run();
int run(ScriptingEnvironment &scripting_environment);
private:
ExtractorConfig config;
std::pair<std::size_t, EdgeID>
BuildEdgeExpandedGraph(lua_State *lua_state,
const ProfileProperties &profile_properties,
BuildEdgeExpandedGraph(ScriptingEnvironment &scripting_environment,
std::vector<QueryNode> &internal_to_external_node_map,
std::vector<EdgeBasedNode> &node_based_edge_list,
std::vector<bool> &node_is_startpoint,

View File

@ -77,7 +77,6 @@ struct ExtractorConfig
intersection_class_data_output_path = basepath + ".osrm.icd";
}
boost::filesystem::path config_file_path;
boost::filesystem::path input_path;
boost::filesystem::path profile_path;

View File

@ -8,7 +8,6 @@
#include <string>
#include <vector>
struct lua_State;
namespace osmium
{
class Relation;
@ -19,7 +18,7 @@ namespace osrm
namespace extractor
{
struct ProfileProperties;
class ScriptingEnvironment;
/**
* Parses the relations that represents turn restrictions.
@ -42,11 +41,10 @@ struct ProfileProperties;
class RestrictionParser
{
public:
RestrictionParser(lua_State *lua_state, const ProfileProperties &properties);
RestrictionParser(ScriptingEnvironment &scripting_environment);
boost::optional<InputRestrictionContainer> TryParse(const osmium::Relation &relation) const;
private:
void ReadRestrictionExceptions(lua_State *lua_state);
bool ShouldIgnoreRestriction(const std::string &except_tag_string) const;
std::vector<std::string> restriction_exceptions;

View File

@ -1,52 +1,70 @@
#ifndef SCRIPTING_ENVIRONMENT_HPP
#define SCRIPTING_ENVIRONMENT_HPP
#include "extractor/guidance/turn_lane_types.hpp"
#include "extractor/internal_extractor_edge.hpp"
#include "extractor/profile_properties.hpp"
#include "extractor/raster_source.hpp"
#include "extractor/restriction.hpp"
#include "util/lua_util.hpp"
#include <osmium/memory/buffer.hpp>
#include <boost/optional/optional.hpp>
#include <tbb/concurrent_vector.h>
#include <memory>
#include <mutex>
#include <string>
#include <tbb/enumerable_thread_specific.h>
#include <vector>
struct lua_State;
namespace osmium
{
class Node;
class Way;
}
namespace osrm
{
namespace util
{
struct Coordinate;
}
namespace extractor
{
class RestrictionParser;
struct ExtractionNode;
struct ExtractionWay;
/**
* Creates a lua context and binds osmium way, node and relation objects and
* ExtractionWay and ExtractionNode to lua objects.
*
* Each thread has its own lua state which is implemented with thread specific
* storage from TBB.
* Abstract class that handles processing osmium ways, nodes and relation objects by applying
* user supplied profiles.
*/
class ScriptingEnvironment
{
public:
struct Context
{
ProfileProperties properties;
SourceContainer sources;
util::LuaState state;
};
explicit ScriptingEnvironment(const std::string &file_name);
ScriptingEnvironment() = default;
ScriptingEnvironment(const ScriptingEnvironment &) = delete;
ScriptingEnvironment &operator=(const ScriptingEnvironment &) = delete;
virtual ~ScriptingEnvironment() = default;
Context &GetContex();
virtual const ProfileProperties &GetProfileProperties() = 0;
private:
void InitContext(Context &context);
std::mutex init_mutex;
std::string file_name;
tbb::enumerable_thread_specific<std::unique_ptr<Context>> script_contexts;
virtual std::vector<std::string> GetNameSuffixList() = 0;
virtual std::vector<std::string> GetExceptions() = 0;
virtual void SetupSources() = 0;
virtual int32_t GetTurnPenalty(double angle) = 0;
virtual void ProcessSegment(const osrm::util::Coordinate &source,
const osrm::util::Coordinate &target,
double distance,
InternalExtractorEdge::WeightData &weight) = 0;
virtual void
ProcessElements(const std::vector<osmium::memory::Buffer::const_iterator> &osm_elements,
const RestrictionParser &restriction_parser,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionNode>> &resulting_nodes,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionWay>> &resulting_ways,
tbb::concurrent_vector<boost::optional<InputRestrictionContainer>>
&resulting_restrictions) = 0;
};
}
}

View File

@ -0,0 +1,80 @@
#ifndef SCRIPTING_ENVIRONMENT_LUA_HPP
#define SCRIPTING_ENVIRONMENT_LUA_HPP
#include "extractor/scripting_environment.hpp"
#include "extractor/raster_source.hpp"
#include "util/lua_util.hpp"
#include <tbb/enumerable_thread_specific.h>
#include <memory>
#include <mutex>
#include <string>
struct lua_State;
namespace osrm
{
namespace extractor
{
struct LuaScriptingContext final
{
void processNode(const osmium::Node &, ExtractionNode &result);
void processWay(const osmium::Way &, ExtractionWay &result);
ProfileProperties properties;
SourceContainer sources;
util::LuaState state;
bool has_turn_penalty_function;
bool has_node_function;
bool has_way_function;
bool has_segment_function;
};
/**
* Creates a lua context and binds osmium way, node and relation objects and
* ExtractionWay and ExtractionNode to lua objects.
*
* Each thread has its own lua state which is implemented with thread specific
* storage from TBB.
*/
class LuaScriptingEnvironment final : public ScriptingEnvironment
{
public:
explicit LuaScriptingEnvironment(const std::string &file_name);
~LuaScriptingEnvironment() override = default;
const ProfileProperties& GetProfileProperties() override;
LuaScriptingContext &GetLuaContext();
std::vector<std::string> GetNameSuffixList() override;
std::vector<std::string> GetExceptions() override;
void SetupSources() override;
int32_t GetTurnPenalty(double angle) override;
void ProcessSegment(const osrm::util::Coordinate &source,
const osrm::util::Coordinate &target,
double distance,
InternalExtractorEdge::WeightData &weight) override;
void
ProcessElements(const std::vector<osmium::memory::Buffer::const_iterator> &osm_elements,
const RestrictionParser &restriction_parser,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionNode>> &resulting_nodes,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionWay>> &resulting_ways,
tbb::concurrent_vector<boost::optional<InputRestrictionContainer>>
&resulting_restrictions) override;
private:
void InitContext(LuaScriptingContext &context);
std::mutex init_mutex;
std::string file_name;
tbb::enumerable_thread_specific<std::unique_ptr<LuaScriptingContext>> script_contexts;
};
}
}
#endif /* SCRIPTING_ENVIRONMENT_LUA_HPP */

View File

@ -4,19 +4,20 @@
#include <string>
#include <unordered_set>
struct lua_State;
namespace osrm
{
namespace extractor
{
class ScriptingEnvironment;
// A table containing suffixes.
// At the moment, it is only a front for an unordered set. At some point we might want to make it
// country dependent and have it behave accordingly
class SuffixTable final
{
public:
SuffixTable(lua_State *lua_state);
SuffixTable(ScriptingEnvironment &scripting_environment);
// check whether a string is part of the know suffix list
bool isSuffix(const std::string &possible_suffix) const;

View File

@ -4,7 +4,6 @@
#include "util/coordinate_calculation.hpp"
#include "util/exception.hpp"
#include "util/integer_range.hpp"
#include "util/lua_util.hpp"
#include "util/percent.hpp"
#include "util/simple_logger.hpp"
#include "util/timing_util.hpp"
@ -12,6 +11,7 @@
#include "extractor/guidance/toolkit.hpp"
#include "extractor/guidance/turn_analysis.hpp"
#include "extractor/guidance/turn_lane_handler.hpp"
#include "extractor/scripting_environment.hpp"
#include "extractor/suffix_table.hpp"
#include <boost/assert.hpp>
@ -182,9 +182,9 @@ void EdgeBasedGraphFactory::FlushVectorToStream(
original_edge_data_vector.clear();
}
void EdgeBasedGraphFactory::Run(const std::string &original_edge_data_filename,
void EdgeBasedGraphFactory::Run(ScriptingEnvironment &scripting_environment,
const std::string &original_edge_data_filename,
const std::string &turn_lane_data_filename,
lua_State *lua_state,
const std::string &edge_segment_lookup_filename,
const std::string &edge_penalty_filename,
const bool generate_edge_lookup)
@ -199,9 +199,9 @@ void EdgeBasedGraphFactory::Run(const std::string &original_edge_data_filename,
TIMER_STOP(generate_nodes);
TIMER_START(generate_edges);
GenerateEdgeExpandedEdges(original_edge_data_filename,
GenerateEdgeExpandedEdges(scripting_environment,
original_edge_data_filename,
turn_lane_data_filename,
lua_state,
edge_segment_lookup_filename,
edge_penalty_filename,
generate_edge_lookup);
@ -298,18 +298,15 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedNodes()
/// Actually it also generates OriginalEdgeData and serializes them...
void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges(
ScriptingEnvironment &scripting_environment,
const std::string &original_edge_data_filename,
const std::string &turn_lane_data_filename,
lua_State *lua_state,
const std::string &edge_segment_lookup_filename,
const std::string &edge_fixed_penalties_filename,
const bool generate_edge_lookup)
{
util::SimpleLogger().Write() << "generating edge-expanded edges";
BOOST_ASSERT(lua_state != nullptr);
const bool use_turn_function = util::luaFunctionExists(lua_state, "turn_function");
std::size_t node_based_edge_counter = 0;
std::size_t original_edges_counter = 0;
restricted_turns_counter = 0;
@ -338,7 +335,7 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges(
// Three nested loop look super-linear, but we are dealing with a (kind of)
// linear number of turns only.
util::Percent progress(m_node_based_graph->GetNumberOfNodes());
SuffixTable street_name_suffix_table(lua_state);
SuffixTable street_name_suffix_table(scripting_environment);
guidance::TurnAnalysis turn_analysis(*m_node_based_graph,
m_node_info_list,
*m_restriction_map,
@ -410,8 +407,6 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges(
for (const auto turn : possible_turns)
{
const double turn_angle = turn.angle;
// only add an edge if turn is not prohibited
const EdgeData &edge_data1 = m_node_based_graph->GetEdgeData(edge_from_u);
const EdgeData &edge_data2 = m_node_based_graph->GetEdgeData(turn.eid);
@ -427,8 +422,7 @@ void EdgeBasedGraphFactory::GenerateEdgeExpandedEdges(
distance += profile_properties.traffic_signal_penalty;
}
const int turn_penalty =
use_turn_function ? GetTurnPenalty(turn_angle, lua_state) : 0;
const int32_t turn_penalty = scripting_environment.GetTurnPenalty(180. - turn.angle);
const auto turn_instruction = turn.instruction;
if (guidance::isUturn(turn_instruction))
@ -615,23 +609,5 @@ std::vector<util::guidance::EntryClass> EdgeBasedGraphFactory::GetEntryClasses()
return result;
}
int EdgeBasedGraphFactory::GetTurnPenalty(double angle, lua_State *lua_state) const
{
BOOST_ASSERT(lua_state != nullptr);
try
{
// call lua profile to compute turn penalty
double penalty = luabind::call_function<double>(lua_state, "turn_function", 180. - angle);
BOOST_ASSERT(penalty < std::numeric_limits<int>::max());
BOOST_ASSERT(penalty > std::numeric_limits<int>::min());
return boost::numeric_cast<int>(penalty);
}
catch (const luabind::error &er)
{
util::SimpleLogger().Write(logWARNING) << er.what();
}
return 0;
}
} // namespace extractor
} // namespace osrm

View File

@ -7,7 +7,6 @@
#include "util/exception.hpp"
#include "util/fingerprint.hpp"
#include "util/io.hpp"
#include "util/lua_util.hpp"
#include "util/simple_logger.hpp"
#include "util/timing_util.hpp"
@ -17,8 +16,6 @@
#include <boost/numeric/conversion/cast.hpp>
#include <boost/ref.hpp>
#include <luabind/luabind.hpp>
#include <stxxl/sort>
#include <chrono>
@ -137,11 +134,11 @@ ExtractionContainers::ExtractionContainers()
* - merge edges with nodes to include location of start/end points and serialize
*
*/
void ExtractionContainers::PrepareData(const std::string &output_file_name,
void ExtractionContainers::PrepareData(ScriptingEnvironment &scripting_environment,
const std::string &output_file_name,
const std::string &restrictions_file_name,
const std::string &name_file_name,
const std::string &turn_lane_file_name,
lua_State *segment_state)
const std::string &turn_lane_file_name)
{
try
{
@ -152,7 +149,7 @@ void ExtractionContainers::PrepareData(const std::string &output_file_name,
PrepareNodes();
WriteNodes(file_out_stream);
PrepareEdges(segment_state);
PrepareEdges(scripting_environment);
WriteEdges(file_out_stream);
PrepareRestrictions();
@ -304,7 +301,7 @@ void ExtractionContainers::PrepareNodes()
std::cout << "ok, after " << TIMER_SEC(id_map) << "s" << std::endl;
}
void ExtractionContainers::PrepareEdges(lua_State *segment_state)
void ExtractionContainers::PrepareEdges(ScriptingEnvironment &scripting_environment)
{
// Sort edges by start.
std::cout << "[extractor] Sorting edges by start ... " << std::flush;
@ -386,8 +383,6 @@ void ExtractionContainers::PrepareEdges(lua_State *segment_state)
const auto all_edges_list_end_ = all_edges_list.end();
const auto all_nodes_list_end_ = all_nodes_list.end();
const auto has_segment_function = util::luaFunctionExists(segment_state, "segment_function");
while (edge_iterator != all_edges_list_end_ && node_iterator != all_nodes_list_end_)
{
// skip all invalid edges
@ -423,15 +418,8 @@ void ExtractionContainers::PrepareEdges(lua_State *segment_state)
edge_iterator->source_coordinate,
util::Coordinate(node_iterator->lon, node_iterator->lat));
if (has_segment_function)
{
luabind::call_function<void>(segment_state,
"segment_function",
boost::cref(edge_iterator->source_coordinate),
boost::cref(*node_iterator),
distance,
boost::ref(edge_iterator->weight_data));
}
scripting_environment.ProcessSegment(
edge_iterator->source_coordinate, *node_iterator, distance, edge_iterator->weight_data);
const double weight = [distance](const InternalExtractorEdge::WeightData &data) {
switch (data.type)

View File

@ -11,7 +11,6 @@
#include "extractor/raster_source.hpp"
#include "util/graph_loader.hpp"
#include "util/io.hpp"
#include "util/lua_util.hpp"
#include "util/make_unique.hpp"
#include "util/name_table.hpp"
#include "util/range_table.hpp"
@ -29,11 +28,9 @@
#include <boost/filesystem/fstream.hpp>
#include <boost/optional/optional.hpp>
#include <luabind/luabind.hpp>
#include <osmium/io/any_input.hpp>
#include <tbb/parallel_for.h>
#include <tbb/concurrent_vector.h>
#include <tbb/task_scheduler_init.h>
#include <cstdlib>
@ -74,11 +71,8 @@ namespace extractor
* graph
*
*/
int Extractor::run()
int Extractor::run(ScriptingEnvironment &scripting_environment)
{
// setup scripting environment
ScriptingEnvironment scripting_environment(config.profile_path.string().c_str());
try
{
util::LogPolicy::GetInstance().Unmute();
@ -90,7 +84,9 @@ int Extractor::run()
tbb::task_scheduler_init init(number_of_threads);
util::SimpleLogger().Write() << "Input file: " << config.input_path.filename().string();
if (!config.profile_path.empty()) {
util::SimpleLogger().Write() << "Profile: " << config.profile_path.filename().string();
}
util::SimpleLogger().Write() << "Threads: " << number_of_threads;
ExtractionContainers extraction_containers;
@ -100,21 +96,15 @@ int Extractor::run()
osmium::io::Reader reader(input_file);
const osmium::io::Header header = reader.header();
std::atomic<unsigned> number_of_nodes{0};
std::atomic<unsigned> number_of_ways{0};
std::atomic<unsigned> number_of_relations{0};
std::atomic<unsigned> number_of_others{0};
unsigned number_of_nodes = 0;
unsigned number_of_ways = 0;
unsigned number_of_relations = 0;
util::SimpleLogger().Write() << "Parsing in progress..";
TIMER_START(parsing);
auto &main_context = scripting_environment.GetContex();
// setup raster sources
if (util::luaFunctionExists(main_context.state, "source_function"))
{
luabind::call_function<void>(main_context.state, "source_function");
}
scripting_environment.SetupSources();
std::string generator = header.get("generator");
if (generator.empty())
@ -140,7 +130,7 @@ int Extractor::run()
tbb::concurrent_vector<boost::optional<InputRestrictionContainer>> resulting_restrictions;
// setup restriction parser
const RestrictionParser restriction_parser(main_context.state, main_context.properties);
const RestrictionParser restriction_parser(scripting_environment);
while (const osmium::memory::Buffer buffer = reader.read())
{
@ -156,52 +146,14 @@ int Extractor::run()
resulting_ways.clear();
resulting_restrictions.clear();
// parse OSM entities in parallel, store in resulting vectors
tbb::parallel_for(
tbb::blocked_range<std::size_t>(0, osm_elements.size()),
[&](const tbb::blocked_range<std::size_t> &range) {
ExtractionNode result_node;
ExtractionWay result_way;
auto &local_context = scripting_environment.GetContex();
scripting_environment.ProcessElements(osm_elements,
restriction_parser,
resulting_nodes,
resulting_ways,
resulting_restrictions);
for (auto x = range.begin(), end = range.end(); x != end; ++x)
{
const auto entity = osm_elements[x];
switch (entity->type())
{
case osmium::item_type::node:
result_node.clear();
++number_of_nodes;
luabind::call_function<void>(
local_context.state,
"node_function",
boost::cref(static_cast<const osmium::Node &>(*entity)),
boost::ref(result_node));
resulting_nodes.push_back(std::make_pair(x, std::move(result_node)));
break;
case osmium::item_type::way:
result_way.clear();
++number_of_ways;
luabind::call_function<void>(
local_context.state,
"way_function",
boost::cref(static_cast<const osmium::Way &>(*entity)),
boost::ref(result_way));
resulting_ways.push_back(std::make_pair(x, std::move(result_way)));
break;
case osmium::item_type::relation:
++number_of_relations;
resulting_restrictions.push_back(restriction_parser.TryParse(
static_cast<const osmium::Relation &>(*entity)));
break;
default:
++number_of_others;
break;
}
}
});
number_of_nodes += resulting_nodes.size();
// put parsed objects thru extractor callbacks
for (const auto &result : resulting_nodes)
{
@ -209,11 +161,13 @@ int Extractor::run()
static_cast<const osmium::Node &>(*(osm_elements[result.first])),
result.second);
}
number_of_ways += resulting_ways.size();
for (const auto &result : resulting_ways)
{
extractor_callbacks->ProcessWay(
static_cast<const osmium::Way &>(*(osm_elements[result.first])), result.second);
}
number_of_relations += resulting_restrictions.size();
for (const auto &result : resulting_restrictions)
{
extractor_callbacks->ProcessRestriction(result);
@ -223,10 +177,9 @@ int Extractor::run()
util::SimpleLogger().Write() << "Parsing finished after " << TIMER_SEC(parsing)
<< " seconds";
util::SimpleLogger().Write() << "Raw input contains " << number_of_nodes.load()
<< " nodes, " << number_of_ways.load() << " ways, and "
<< number_of_relations.load() << " relations, and "
<< number_of_others.load() << " unknown entities";
util::SimpleLogger().Write() << "Raw input contains " << number_of_nodes << " nodes, "
<< number_of_ways << " ways, and " << number_of_relations
<< " relations";
extractor_callbacks.reset();
@ -236,13 +189,14 @@ int Extractor::run()
return 1;
}
extraction_containers.PrepareData(config.output_file_name,
extraction_containers.PrepareData(scripting_environment,
config.output_file_name,
config.restriction_file_name,
config.names_file_name,
config.turn_lane_descriptions_file_name,
main_context.state);
config.turn_lane_descriptions_file_name);
WriteProfileProperties(config.profile_properties_output_path, main_context.properties);
WriteProfileProperties(config.profile_properties_output_path,
scripting_environment.GetProfileProperties());
TIMER_STOP(extracting);
util::SimpleLogger().Write() << "extraction finished after " << TIMER_SEC(extracting)
@ -261,9 +215,6 @@ int Extractor::run()
// that is better for routing. Every edge becomes a node, and every valid
// movement (e.g. turn from A->B, and B->A) becomes an edge
//
auto &main_context = scripting_environment.GetContex();
util::SimpleLogger().Write() << "Generating edge-expanded graph representation";
TIMER_START(expansion);
@ -273,8 +224,7 @@ int Extractor::run()
std::vector<bool> node_is_startpoint;
std::vector<EdgeWeight> edge_based_node_weights;
std::vector<QueryNode> internal_to_external_node_map;
auto graph_size = BuildEdgeExpandedGraph(main_context.state,
main_context.properties,
auto graph_size = BuildEdgeExpandedGraph(scripting_environment,
internal_to_external_node_map,
edge_based_node_list,
node_is_startpoint,
@ -477,8 +427,7 @@ Extractor::LoadNodeBasedGraph(std::unordered_set<NodeID> &barrier_nodes,
\brief Building an edge-expanded graph from node-based input and turn restrictions
*/
std::pair<std::size_t, EdgeID>
Extractor::BuildEdgeExpandedGraph(lua_State *lua_state,
const ProfileProperties &profile_properties,
Extractor::BuildEdgeExpandedGraph(ScriptingEnvironment &scripting_environment,
std::vector<QueryNode> &internal_to_external_node_map,
std::vector<EdgeBasedNode> &node_based_edge_list,
std::vector<bool> &node_is_startpoint,
@ -520,14 +469,14 @@ Extractor::BuildEdgeExpandedGraph(lua_State *lua_state,
traffic_lights,
std::const_pointer_cast<RestrictionMap const>(restriction_map),
internal_to_external_node_map,
profile_properties,
scripting_environment.GetProfileProperties(),
name_table,
turn_lane_offsets,
turn_lane_masks);
edge_based_graph_factory.Run(config.edge_output_path,
edge_based_graph_factory.Run(scripting_environment,
config.edge_output_path,
config.turn_lane_data_file_name,
lua_state,
config.edge_segment_lookup_path,
config.edge_penalty_path,
config.generate_edge_lookup);

View File

@ -1,9 +1,9 @@
#include "extractor/restriction_parser.hpp"
#include "extractor/profile_properties.hpp"
#include "extractor/scripting_environment.hpp"
#include "extractor/external_memory_node.hpp"
#include "util/exception.hpp"
#include "util/lua_util.hpp"
#include "util/simple_logger.hpp"
#include <boost/algorithm/string.hpp>
@ -24,33 +24,15 @@ namespace osrm
namespace extractor
{
namespace
{
int luaErrorCallback(lua_State *lua_state)
{
std::string error_msg = lua_tostring(lua_state, -1);
throw util::exception("ERROR occurred in profile script:\n" + error_msg);
}
}
RestrictionParser::RestrictionParser(lua_State *lua_state, const ProfileProperties &properties)
: use_turn_restrictions(properties.use_turn_restrictions)
RestrictionParser::RestrictionParser(ScriptingEnvironment &scripting_environment)
: use_turn_restrictions(scripting_environment.GetProfileProperties().use_turn_restrictions)
{
if (use_turn_restrictions)
{
ReadRestrictionExceptions(lua_state);
}
}
void RestrictionParser::ReadRestrictionExceptions(lua_State *lua_state)
{
if (util::luaFunctionExists(lua_state, "get_exceptions"))
{
luabind::set_pcall_callback(&luaErrorCallback);
// get list of turn restriction exceptions
luabind::call_function<void>(
lua_state, "get_exceptions", boost::ref(restriction_exceptions));
restriction_exceptions = scripting_environment.GetExceptions();
const unsigned exception_count = restriction_exceptions.size();
if (exception_count)
{
util::SimpleLogger().Write() << "Found " << exception_count
<< " exceptions to turn restrictions:";
for (const std::string &str : restriction_exceptions)
@ -63,6 +45,7 @@ void RestrictionParser::ReadRestrictionExceptions(lua_State *lua_state)
util::SimpleLogger().Write() << "Found no exceptions to turn restrictions";
}
}
}
/**
* Tries to parse an relation as turn restriction. This can fail for a number of

View File

@ -1,4 +1,4 @@
#include "extractor/scripting_environment.hpp"
#include "extractor/scripting_environment_lua.hpp"
#include "extractor/external_memory_node.hpp"
#include "extractor/extraction_helper_functions.hpp"
@ -7,6 +7,7 @@
#include "extractor/internal_extractor_edge.hpp"
#include "extractor/profile_properties.hpp"
#include "extractor/raster_source.hpp"
#include "extractor/restriction_parser.hpp"
#include "util/exception.hpp"
#include "util/lua_util.hpp"
#include "util/make_unique.hpp"
@ -19,6 +20,8 @@
#include <osmium/osm.hpp>
#include <tbb/parallel_for.h>
#include <sstream>
namespace osrm
@ -58,12 +61,13 @@ int luaErrorCallback(lua_State *state)
}
}
ScriptingEnvironment::ScriptingEnvironment(const std::string &file_name) : file_name(file_name)
LuaScriptingEnvironment::LuaScriptingEnvironment(const std::string &file_name)
: file_name(file_name)
{
util::SimpleLogger().Write() << "Using script " << file_name;
}
void ScriptingEnvironment::InitContext(ScriptingEnvironment::Context &context)
void LuaScriptingEnvironment::InitContext(LuaScriptingContext &context)
{
typedef double (osmium::Location::*location_member_ptr_type)() const;
@ -187,21 +191,180 @@ void ScriptingEnvironment::InitContext(ScriptingEnvironment::Context &context)
error_stream << error_msg;
throw util::exception("ERROR occurred in profile script:\n" + error_stream.str());
}
context.has_turn_penalty_function = util::luaFunctionExists(context.state, "turn_function");
context.has_node_function = util::luaFunctionExists(context.state, "node_function");
context.has_way_function = util::luaFunctionExists(context.state, "way_function");
context.has_segment_function = util::luaFunctionExists(context.state, "segment_function");
}
ScriptingEnvironment::Context &ScriptingEnvironment::GetContex()
const ProfileProperties &LuaScriptingEnvironment::GetProfileProperties()
{
return GetLuaContext().properties;
}
LuaScriptingContext &LuaScriptingEnvironment::GetLuaContext()
{
std::lock_guard<std::mutex> lock(init_mutex);
bool initialized = false;
auto &ref = script_contexts.local(initialized);
if (!initialized)
{
ref = util::make_unique<Context>();
ref = util::make_unique<LuaScriptingContext>();
InitContext(*ref);
}
luabind::set_pcall_callback(&luaErrorCallback);
return *ref;
}
void LuaScriptingEnvironment::ProcessElements(
const std::vector<osmium::memory::Buffer::const_iterator> &osm_elements,
const RestrictionParser &restriction_parser,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionNode>> &resulting_nodes,
tbb::concurrent_vector<std::pair<std::size_t, ExtractionWay>> &resulting_ways,
tbb::concurrent_vector<boost::optional<InputRestrictionContainer>> &resulting_restrictions)
{
// parse OSM entities in parallel, store in resulting vectors
tbb::parallel_for(
tbb::blocked_range<std::size_t>(0, osm_elements.size()),
[&](const tbb::blocked_range<std::size_t> &range) {
ExtractionNode result_node;
ExtractionWay result_way;
auto &local_context = this->GetLuaContext();
for (auto x = range.begin(), end = range.end(); x != end; ++x)
{
const auto entity = osm_elements[x];
switch (entity->type())
{
case osmium::item_type::node:
result_node.clear();
if (local_context.has_node_function)
{
local_context.processNode(static_cast<const osmium::Node &>(*entity),
result_node);
}
resulting_nodes.push_back(std::make_pair(x, std::move(result_node)));
break;
case osmium::item_type::way:
result_way.clear();
if (local_context.has_way_function)
{
local_context.processWay(static_cast<const osmium::Way &>(*entity),
result_way);
}
resulting_ways.push_back(std::make_pair(x, std::move(result_way)));
break;
case osmium::item_type::relation:
resulting_restrictions.push_back(restriction_parser.TryParse(
static_cast<const osmium::Relation &>(*entity)));
break;
default:
break;
}
}
});
}
std::vector<std::string> LuaScriptingEnvironment::GetNameSuffixList()
{
auto &context = GetLuaContext();
BOOST_ASSERT(context.state != nullptr);
if (!util::luaFunctionExists(context.state, "get_name_suffix_list"))
return {};
std::vector<std::string> suffixes_vector;
try
{
// call lua profile to compute turn penalty
luabind::call_function<void>(
context.state, "get_name_suffix_list", boost::ref(suffixes_vector));
}
catch (const luabind::error &er)
{
util::SimpleLogger().Write(logWARNING) << er.what();
}
return suffixes_vector;
}
std::vector<std::string> LuaScriptingEnvironment::GetExceptions()
{
auto &context = GetLuaContext();
BOOST_ASSERT(context.state != nullptr);
std::vector<std::string> restriction_exceptions;
if (util::luaFunctionExists(context.state, "get_exceptions"))
{
// get list of turn restriction exceptions
luabind::call_function<void>(
context.state, "get_exceptions", boost::ref(restriction_exceptions));
}
return restriction_exceptions;
}
void LuaScriptingEnvironment::SetupSources()
{
auto &context = GetLuaContext();
BOOST_ASSERT(context.state != nullptr);
if (util::luaFunctionExists(context.state, "source_function"))
{
luabind::call_function<void>(context.state, "source_function");
}
}
int32_t LuaScriptingEnvironment::GetTurnPenalty(const double angle)
{
auto &context = GetLuaContext();
if (context.has_turn_penalty_function)
{
BOOST_ASSERT(context.state != nullptr);
try
{
// call lua profile to compute turn penalty
const double penalty =
luabind::call_function<double>(context.state, "turn_function", angle);
BOOST_ASSERT(penalty < std::numeric_limits<int32_t>::max());
BOOST_ASSERT(penalty > std::numeric_limits<int32_t>::min());
return boost::numeric_cast<int32_t>(penalty);
}
catch (const luabind::error &er)
{
util::SimpleLogger().Write(logWARNING) << er.what();
}
}
return 0;
}
void LuaScriptingEnvironment::ProcessSegment(const osrm::util::Coordinate &source,
const osrm::util::Coordinate &target,
double distance,
InternalExtractorEdge::WeightData &weight)
{
auto &context = GetLuaContext();
if (context.has_segment_function)
{
BOOST_ASSERT(context.state != nullptr);
luabind::call_function<void>(context.state,
"segment_function",
boost::cref(source),
boost::cref(target),
distance,
boost::ref(weight));
}
}
void LuaScriptingContext::processNode(const osmium::Node &node, ExtractionNode &result)
{
BOOST_ASSERT(state != nullptr);
luabind::call_function<void>(state, "node_function", boost::cref(node), boost::ref(result));
}
void LuaScriptingContext::processWay(const osmium::Way &way, ExtractionWay &result)
{
BOOST_ASSERT(state != nullptr);
luabind::call_function<void>(state, "way_function", boost::cref(way), boost::ref(result));
}
}
}

View File

@ -1,38 +1,17 @@
#include "extractor/suffix_table.hpp"
#include "util/lua_util.hpp"
#include "util/simple_logger.hpp"
#include "extractor/scripting_environment.hpp"
#include <boost/algorithm/string.hpp>
#include <boost/assert.hpp>
#include <boost/ref.hpp>
#include <iterator>
#include <vector>
namespace osrm
{
namespace extractor
{
SuffixTable::SuffixTable(lua_State *lua_state)
SuffixTable::SuffixTable(ScriptingEnvironment &scripting_environment)
{
BOOST_ASSERT(lua_state != nullptr);
if (!util::luaFunctionExists(lua_state, "get_name_suffix_list"))
return;
std::vector<std::string> suffixes_vector;
try
{
// call lua profile to compute turn penalty
luabind::call_function<void>(
lua_state, "get_name_suffix_list", boost::ref(suffixes_vector));
}
catch (const luabind::error &er)
{
util::SimpleLogger().Write(logWARNING) << er.what();
}
std::vector<std::string> suffixes_vector = scripting_environment.GetNameSuffixList();
for (auto &suffix : suffixes_vector)
boost::algorithm::to_lower(suffix);
suffix_set.insert(std::begin(suffixes_vector), std::end(suffixes_vector));

View File

@ -1,5 +1,6 @@
#include "extractor/extractor.hpp"
#include "extractor/extractor_config.hpp"
#include "extractor/scripting_environment_lua.hpp"
#include "util/simple_logger.hpp"
#include "util/version.hpp"
@ -147,7 +148,11 @@ int main(int argc, char *argv[]) try
<< "Profile " << extractor_config.profile_path.string() << " not found!";
return EXIT_FAILURE;
}
return extractor::Extractor(extractor_config).run();
// setup scripting environment
extractor::LuaScriptingEnvironment scripting_environment(
extractor_config.profile_path.string().c_str());
return extractor::Extractor(extractor_config).run(scripting_environment);
}
catch (const std::bad_alloc &e)
{