diff --git a/include/extractor/scripting_environment_lua.hpp b/include/extractor/scripting_environment_lua.hpp index b81e9024c..ef5243454 100644 --- a/include/extractor/scripting_environment_lua.hpp +++ b/include/extractor/scripting_environment_lua.hpp @@ -41,10 +41,10 @@ struct LuaScriptingContext final bool has_way_function = false; bool has_segment_function = false; - sol::function turn_function; - sol::function way_function; - sol::function node_function; - sol::function segment_function; + sol::protected_function turn_function; + sol::protected_function way_function; + sol::protected_function node_function; + sol::protected_function segment_function; int api_version = 4; sol::table profile_table; diff --git a/src/extractor/scripting_environment_lua.cpp b/src/extractor/scripting_environment_lua.cpp index 0ab3df52d..2461f9e3c 100644 --- a/src/extractor/scripting_environment_lua.cpp +++ b/src/extractor/scripting_environment_lua.cpp @@ -91,6 +91,19 @@ struct to_lua_object : public boost::static_visitor }; } // namespace +// Handle a lua error thrown in a protected function by printing the traceback and bubbling +// exception up to caller. Lua errors are generally unrecoverable, so this exception should not be +// caught but instead should terminate the process. The point of having this error handler rather +// than just using unprotected Lua functions which terminate the process automatically is that this +// function provides more useful error messages including Lua tracebacks and line numbers. +void handle_lua_error(sol::protected_function_result &luares) +{ + sol::error luaerr = luares; + std::string msg = luaerr.what(); + std::cerr << msg << std::endl; + throw util::exception("Lua error (see stderr for traceback)"); +} + Sol2ScriptingEnvironment::Sol2ScriptingEnvironment( const std::string &file_name, const std::vector &location_dependent_data_paths) @@ -232,7 +245,8 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) "valid", &osmium::Location::valid); - auto get_location_tag = [](auto &context, const auto &location, const char *key) { + auto get_location_tag = [](auto &context, const auto &location, const char *key) + { if (context.location_dependent_data.empty()) return sol::object(context.state); @@ -259,7 +273,8 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) "get_nodes", [](const osmium::Way &way) { return sol::as_table(&way.nodes()); }, "get_location_tag", - [&context, &get_location_tag](const osmium::Way &way, const char *key) { + [&context, &get_location_tag](const osmium::Way &way, const char *key) + { // HEURISTIC: use a single node (last) of the way to localize the way // For more complicated scenarios a proper merging of multiple tags // at one or many locations must be provided @@ -279,9 +294,8 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) "version", &osmium::Node::version, "get_location_tag", - [&context, &get_location_tag](const osmium::Node &node, const char *key) { - return get_location_tag(context, node.location(), key); - }); + [&context, &get_location_tag](const osmium::Node &node, const char *key) + { return get_location_tag(context, node.location(), key); }); context.state.new_enum("traffic_lights", "none", @@ -297,7 +311,8 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) "ResultNode", "traffic_lights", sol::property([](const ExtractionNode &node) { return node.traffic_lights; }, - [](ExtractionNode &node, const sol::object &obj) { + [](ExtractionNode &node, const sol::object &obj) + { if (obj.is()) { // The old approach of assigning a boolean traffic light @@ -348,7 +363,8 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) sol::property(&ExtractionWay::GetName, &ExtractionWay::SetName), "ref", // backward compatibility sol::property(&ExtractionWay::GetForwardRef, - [](ExtractionWay &way, const char *ref) { + [](ExtractionWay &way, const char *ref) + { way.SetForwardRef(ref); way.SetBackwardRef(ref); }), @@ -407,7 +423,8 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) sol::property([](const ExtractionWay &way) { return way.access_turn_classification; }, [](ExtractionWay &way, int flag) { way.access_turn_classification = flag; })); - auto getTypedRefBySol = [](const sol::object &obj) -> ExtractionRelation::OsmIDTyped { + auto getTypedRefBySol = [](const sol::object &obj) -> ExtractionRelation::OsmIDTyped + { if (obj.is()) { osmium::Way *way = obj.as(); @@ -443,20 +460,19 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) "get_value_by_key", [](ExtractionRelation &rel, const char *key) -> const char * { return rel.GetAttr(key); }, "get_role", - [&getTypedRefBySol](ExtractionRelation &rel, const sol::object &obj) -> const char * { - return rel.GetRole(getTypedRefBySol(obj)); - }); + [&getTypedRefBySol](ExtractionRelation &rel, const sol::object &obj) -> const char * + { return rel.GetRole(getTypedRefBySol(obj)); }); context.state.new_usertype( "ExtractionRelationContainer", "get_relations", [&getTypedRefBySol](ExtractionRelationContainer &cont, const sol::object &obj) - -> const ExtractionRelationContainer::RelationIDList & { - return cont.GetRelations(getTypedRefBySol(obj)); - }, + -> const ExtractionRelationContainer::RelationIDList & + { return cont.GetRelations(getTypedRefBySol(obj)); }, "relation", - [](ExtractionRelationContainer &cont, const ExtractionRelation::OsmIDTyped &rel_id) - -> const ExtractionRelation & { return cont.GetRelationData(rel_id); }); + [](ExtractionRelationContainer &cont, + const ExtractionRelation::OsmIDTyped &rel_id) -> const ExtractionRelation & + { return cont.GetRelationData(rel_id); }); context.state.new_usertype("ExtractionSegment", "source", @@ -472,10 +488,12 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) // Keep in mind .location is available only if .pbf is preprocessed to set the location with the // ref using osmium command "osmium add-locations-to-ways" - context.state.new_usertype( - "NodeRef", "id", &osmium::NodeRef::ref, "location", [](const osmium::NodeRef &nref) { - return nref.location(); - }); + context.state.new_usertype("NodeRef", + "id", + &osmium::NodeRef::ref, + "location", + [](const osmium::NodeRef &nref) + { return nref.location(); }); context.state.new_usertype("EdgeSource", "source_coordinate", @@ -531,7 +549,8 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) util::Log() << "Using profile api version " << context.api_version; // version-dependent parts of the api - auto initV2Context = [&]() { + auto initV2Context = [&]() + { // clear global not used in v2 context.state["properties"] = sol::nullopt; @@ -550,10 +569,16 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) std::numeric_limits::max()); // call initialize function - sol::function setup_function = function_table.value()["setup"]; + sol::protected_function setup_function = function_table.value()["setup"]; if (!setup_function.valid()) throw util::exception("Profile must have an setup() function."); - sol::optional profile_table = setup_function(); + + auto setup_result = setup_function(); + + if (!setup_result.valid()) + handle_lua_error(setup_result); + + sol::optional profile_table = setup_result; if (profile_table == sol::nullopt) throw util::exception("Profile setup() must return a table."); else @@ -616,26 +641,31 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) } }; - auto initialize_V3_extraction_turn = [&]() { + auto initialize_V3_extraction_turn = [&]() + { context.state.new_usertype( "ExtractionTurn", "angle", &ExtractionTurn::angle, "turn_type", - sol::property([](const ExtractionTurn &turn) { - if (turn.number_of_roads > 2 || turn.source_mode != turn.target_mode || - turn.is_u_turn) - return osrm::guidance::TurnType::Turn; - else - return osrm::guidance::TurnType::NoTurn; - }), + sol::property( + [](const ExtractionTurn &turn) + { + if (turn.number_of_roads > 2 || turn.source_mode != turn.target_mode || + turn.is_u_turn) + return osrm::guidance::TurnType::Turn; + else + return osrm::guidance::TurnType::NoTurn; + }), "direction_modifier", - sol::property([](const ExtractionTurn &turn) { - if (turn.is_u_turn) - return osrm::guidance::DirectionModifier::UTurn; - else - return osrm::guidance::DirectionModifier::Straight; - }), + sol::property( + [](const ExtractionTurn &turn) + { + if (turn.is_u_turn) + return osrm::guidance::DirectionModifier::UTurn; + else + return osrm::guidance::DirectionModifier::Straight; + }), "has_traffic_light", &ExtractionTurn::has_traffic_light, "weight", @@ -845,10 +875,12 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) BOOST_ASSERT(context.properties.GetTrafficSignalPenalty() == 0); // call source_function if present - sol::function source_function = context.state["source_function"]; + sol::protected_function source_function = context.state["source_function"]; if (source_function.valid()) { - source_function(); + auto luares = source_function(); + if (!luares.valid()) + handle_lua_error(luares); } break; @@ -961,10 +993,12 @@ Sol2ScriptingEnvironment::GetStringListFromFunction(const std::string &function_ auto &context = GetSol2Context(); BOOST_ASSERT(context.state.lua_state()); std::vector strings; - sol::function function = context.state[function_name]; + sol::protected_function function = context.state[function_name]; if (function.valid()) { - function(strings); + auto luares = function(strings); + if (!luares.valid()) + handle_lua_error(luares); } return strings; } @@ -1123,7 +1157,9 @@ void Sol2ScriptingEnvironment::ProcessTurn(ExtractionTurn &turn) case 2: if (context.has_turn_penalty_function) { - context.turn_function(context.profile_table, std::ref(turn)); + auto luares = context.turn_function(context.profile_table, std::ref(turn)); + if (!luares.valid()) + handle_lua_error(luares); // Turn weight falls back to the duration value in deciseconds // or uses the extracted unit-less weight value @@ -1138,7 +1174,9 @@ void Sol2ScriptingEnvironment::ProcessTurn(ExtractionTurn &turn) case 1: if (context.has_turn_penalty_function) { - context.turn_function(std::ref(turn)); + auto luares = context.turn_function(std::ref(turn)); + if (!luares.valid()) + handle_lua_error(luares); // Turn weight falls back to the duration value in deciseconds // or uses the extracted unit-less weight value @@ -1184,24 +1222,28 @@ void Sol2ScriptingEnvironment::ProcessSegment(ExtractionSegment &segment) if (context.has_segment_function) { + sol::protected_function_result luares; switch (context.api_version) { case 4: case 3: case 2: - context.segment_function(context.profile_table, std::ref(segment)); + luares = context.segment_function(context.profile_table, std::ref(segment)); break; case 1: - context.segment_function(std::ref(segment)); + luares = context.segment_function(std::ref(segment)); break; case 0: - context.segment_function(std::ref(segment.source), - std::ref(segment.target), - segment.distance, - segment.duration); + luares = context.segment_function(std::ref(segment.source), + std::ref(segment.target), + segment.distance, + segment.duration); segment.weight = segment.duration; // back-compatibility fallback to duration break; } + + if (!luares.valid()) + handle_lua_error(luares); } } @@ -1211,20 +1253,27 @@ void LuaScriptingContext::ProcessNode(const osmium::Node &node, { BOOST_ASSERT(state.lua_state() != nullptr); + sol::protected_function_result luares; + + // TODO check for api version, make sure luares is always set switch (api_version) { case 4: case 3: - node_function(profile_table, std::cref(node), std::ref(result), std::cref(relations)); + luares = + node_function(profile_table, std::cref(node), std::ref(result), std::cref(relations)); break; case 2: - node_function(profile_table, std::cref(node), std::ref(result)); + luares = node_function(profile_table, std::cref(node), std::ref(result)); break; case 1: case 0: - node_function(std::cref(node), std::ref(result)); + luares = node_function(std::cref(node), std::ref(result)); break; } + + if (!luares.valid()) + handle_lua_error(luares); } void LuaScriptingContext::ProcessWay(const osmium::Way &way, @@ -1233,20 +1282,27 @@ void LuaScriptingContext::ProcessWay(const osmium::Way &way, { BOOST_ASSERT(state.lua_state() != nullptr); + sol::protected_function_result luares; + + // TODO check for api version, make sure luares is always set switch (api_version) { case 4: case 3: - way_function(profile_table, std::cref(way), std::ref(result), std::cref(relations)); + luares = + way_function(profile_table, std::cref(way), std::ref(result), std::cref(relations)); break; case 2: - way_function(profile_table, std::cref(way), std::ref(result)); + luares = way_function(profile_table, std::cref(way), std::ref(result)); break; case 1: case 0: - way_function(std::cref(way), std::ref(result)); + luares = way_function(std::cref(way), std::ref(result)); break; } + + if (!luares.valid()) + handle_lua_error(luares); } } // namespace osrm::extractor