diff --git a/include/extractor/scripting_environment.hpp b/include/extractor/scripting_environment.hpp index 97aad0173..8a686e39d 100644 --- a/include/extractor/scripting_environment.hpp +++ b/include/extractor/scripting_environment.hpp @@ -19,6 +19,7 @@ namespace osmium { class Node; class Way; +class Relation; } namespace osrm @@ -35,6 +36,7 @@ namespace extractor class RestrictionParser; struct ExtractionNode; struct ExtractionWay; +struct ExtractionRelation; struct ExtractionTurn; struct ExtractionSegment; @@ -64,6 +66,7 @@ class ScriptingEnvironment const RestrictionParser &restriction_parser, std::vector> &resulting_nodes, std::vector> &resulting_ways, + std::vector> &resulting_relations, std::vector &resulting_restrictions) = 0; }; } diff --git a/include/extractor/scripting_environment_lua.hpp b/include/extractor/scripting_environment_lua.hpp index ecc276518..95b7e3b23 100644 --- a/include/extractor/scripting_environment_lua.hpp +++ b/include/extractor/scripting_environment_lua.hpp @@ -21,6 +21,7 @@ struct LuaScriptingContext final { void ProcessNode(const osmium::Node &, ExtractionNode &result); void ProcessWay(const osmium::Way &, ExtractionWay &result); + void ProcessRelation(const osmium::Relation &, ExtractionRelation &result); ProfileProperties properties; RasterContainer raster_sources; @@ -29,11 +30,13 @@ struct LuaScriptingContext final bool has_turn_penalty_function; bool has_node_function; bool has_way_function; + bool has_relation_function; bool has_segment_function; sol::function turn_function; sol::function way_function; sol::function node_function; + sol::function relation_function; sol::function segment_function; int api_version; @@ -51,7 +54,7 @@ class Sol2ScriptingEnvironment final : public ScriptingEnvironment { public: static const constexpr int SUPPORTED_MIN_API_VERSION = 0; - static const constexpr int SUPPORTED_MAX_API_VERSION = 2; + static const constexpr int SUPPORTED_MAX_API_VERSION = 3; explicit Sol2ScriptingEnvironment(const std::string &file_name); ~Sol2ScriptingEnvironment() override = default; @@ -70,6 +73,7 @@ class Sol2ScriptingEnvironment final : public ScriptingEnvironment const RestrictionParser &restriction_parser, std::vector> &resulting_nodes, std::vector> &resulting_ways, + std::vector> &resulting_relations, std::vector &resulting_restrictions) override; private: diff --git a/src/extractor/extractor.cpp b/src/extractor/extractor.cpp index a850bd144..61e16662c 100644 --- a/src/extractor/extractor.cpp +++ b/src/extractor/extractor.cpp @@ -4,6 +4,7 @@ #include "extractor/extraction_containers.hpp" #include "extractor/extraction_node.hpp" #include "extractor/extraction_way.hpp" +#include "extractor/extraction_relation.hpp" #include "extractor/extractor_callbacks.hpp" #include "extractor/files.hpp" #include "extractor/raster_source.hpp" @@ -346,6 +347,7 @@ Extractor::ParseOSMData(ScriptingEnvironment &scripting_environment, SharedBuffer buffer; std::vector> resulting_nodes; std::vector> resulting_ways; + std::vector> resulting_relations; std::vector resulting_restrictions; }; @@ -372,6 +374,7 @@ Extractor::ParseOSMData(ScriptingEnvironment &scripting_environment, restriction_parser, parsed_buffer->resulting_nodes, parsed_buffer->resulting_ways, + parsed_buffer->resulting_relations, parsed_buffer->resulting_restrictions); return parsed_buffer; }); diff --git a/src/extractor/scripting_environment_lua.cpp b/src/extractor/scripting_environment_lua.cpp index f2682fc86..28b04be68 100644 --- a/src/extractor/scripting_environment_lua.cpp +++ b/src/extractor/scripting_environment_lua.cpp @@ -2,6 +2,7 @@ #include "extractor/extraction_helper_functions.hpp" #include "extractor/extraction_node.hpp" +#include "extractor/extraction_relation.hpp" #include "extractor/extraction_segment.hpp" #include "extractor/extraction_turn.hpp" #include "extractor/extraction_way.hpp" @@ -31,6 +32,9 @@ template <> struct is_container : std::false_type template <> struct is_container : std::false_type { }; +template <> struct is_container : std::false_type +{ +}; } namespace osrm @@ -216,6 +220,14 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) "sharp_left", extractor::guidance::DirectionModifier::SharpLeft); + context.state.new_enum("item_type", + "node", + osmium::item_type::node, + "way", + osmium::item_type::way, + "relation", + osmium::item_type::relation); + context.state.new_usertype("raster", "load", &RasterContainer::LoadRasterSource, @@ -276,6 +288,71 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) "version", &osmium::Way::version); + struct RelationMemberWrap + { + explicit RelationMemberWrap(const osmium::RelationMember & member) + { + init(member); + } + + RelationMemberWrap() + : item_type(osmium::item_type::undefined) + { + } + + void init(const osmium::RelationMember & member) + { + role = member.role(); + item_type = member.type(); + id = member.ref(); + } + + const char* GetRole() const { return role.c_str(); } + osmium::item_type GetItemType() const { return item_type; } + osmium::object_id_type GetId() const { return id; } + + util::OsmIDTyped ref() const { return util::OsmIDTyped(id, std::uint8_t(item_type)); } + + std::string role; + osmium::item_type item_type; + osmium::object_id_type id; + }; + + context.state.new_usertype( + "OsmIDTypes", + "id", + &util::OsmIDTyped::GetID, + "type", + &util::OsmIDTyped::GetType + ); + + context.state.new_usertype( + "RelationMember", + "role", + &RelationMemberWrap::GetRole, + "item_type", + &RelationMemberWrap::GetItemType + ); + + /** TODO: make better solution with members iteration. + * For this moment, just make vector of RelationMember wrappers + */ + context.state.new_usertype("Relation", + "get_value_by_key", + &get_value_by_key, + "id", + &osmium::Relation::id, + "version", + &osmium::Relation::version, + "members", [](const osmium::Relation &rel) + { + std::vector members(rel.members().size()); + size_t i = 0; + for (const auto & m : rel.members()) + members[i++].init(m); + return sol::as_table(std::move(members)); + }); + context.state.new_usertype("Node", "location", &osmium::Node::location, @@ -284,7 +361,7 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) "id", &osmium::Node::id, "version", - &osmium::Way::version); + &osmium::Node::version); context.state.new_usertype("ResultNode", "traffic_lights", @@ -366,6 +443,37 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) sol::property([](const ExtractionWay &way) { return way.backward_restricted; }, [](ExtractionWay &way, bool flag) { way.backward_restricted = flag; })); + struct ExtractionRelationData + { + explicit ExtractionRelationData(ExtractionRelation::AttributesMap & attrs_) + : attrs(attrs_) + { + } + + ExtractionRelation::AttributesMap & attrs; + }; + + context.state.new_usertype( + "ExtractionRelationData", + "size", + [](const ExtractionRelationData & data) { return data.attrs.size(); }, + sol::meta_function::new_index, + [](ExtractionRelationData & data, const std::string & key, sol::stack_object object) { return data.attrs[key] = object.as(); }, + sol::meta_function::index, + [](ExtractionRelationData & data, const std::string & key) { return data.attrs[key]; } + ); + + context.state.new_usertype( + "ExtractionRelation", + sol::meta_function::new_index, + [](ExtractionRelation & rel, const RelationMemberWrap & member) { return ExtractionRelationData(rel.GetMember(member.ref())); }, + sol::meta_function::index, + [](ExtractionRelation & rel, const RelationMemberWrap & member) { return ExtractionRelationData(rel.GetMember(member.ref())); }, + "restriction", + sol::property([](const ExtractionRelation & rel) { return rel.is_restriction; }, + [](ExtractionRelation & rel, bool flag) { rel.is_restriction = flag; }) + ); + context.state.new_usertype("ExtractionSegment", "source", &ExtractionSegment::source, @@ -457,9 +565,7 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) util::Log() << "Using profile api version " << context.api_version; // version-dependent parts of the api - switch (context.api_version) - { - case 2: + auto initV2Context = [&]() { // clear global not used in v2 context.state["properties"] = sol::nullopt; @@ -543,6 +649,22 @@ void Sol2ScriptingEnvironment::InitContext(LuaScriptingContext &context) if (force_split_edges != sol::nullopt) context.properties.force_split_edges = force_split_edges.value(); } + }; + + switch (context.api_version) + { + case 3: + { + initV2Context(); + context.relation_function = function_table.value()["process_relation"]; + + context.has_relation_function = context.relation_function.valid(); + break; + } + + case 2: + { + initV2Context(); break; } case 1: @@ -617,10 +739,12 @@ void Sol2ScriptingEnvironment::ProcessElements( const RestrictionParser &restriction_parser, std::vector> &resulting_nodes, std::vector> &resulting_ways, + std::vector> &resulting_relations, std::vector &resulting_restrictions) { ExtractionNode result_node; ExtractionWay result_way; + ExtractionRelation result_relation; auto &local_context = this->GetSol2Context(); for (auto entity = buffer.cbegin(), end = buffer.cend(); entity != end; ++entity) @@ -655,6 +779,17 @@ void Sol2ScriptingEnvironment::ProcessElements( { resulting_restrictions.push_back(*result_res); } + + if (local_context.api_version > 2) + { + result_relation.clear(); + if (local_context.has_relation_function) + { + local_context.ProcessRelation(static_cast(*entity), result_relation); + } + resulting_relations.push_back(std::pair( + static_cast(*entity), std::move(result_relation))); + } } break; default: @@ -732,6 +867,7 @@ std::vector> Sol2ScriptingEnvironment::GetExcludableCla auto &context = GetSol2Context(); switch (context.api_version) { + case 3: case 2: return Sol2ScriptingEnvironment::GetStringListsFromTable("excludable"); default: @@ -744,6 +880,7 @@ std::vector Sol2ScriptingEnvironment::GetClassNames() auto &context = GetSol2Context(); switch (context.api_version) { + case 3: case 2: return Sol2ScriptingEnvironment::GetStringListFromTable("classes"); default: @@ -756,6 +893,7 @@ std::vector Sol2ScriptingEnvironment::GetNameSuffixList() auto &context = GetSol2Context(); switch (context.api_version) { + case 3: case 2: return Sol2ScriptingEnvironment::GetStringListFromTable("suffix_list"); case 1: @@ -770,6 +908,7 @@ std::vector Sol2ScriptingEnvironment::GetRestrictions() auto &context = GetSol2Context(); switch (context.api_version) { + case 3: case 2: return Sol2ScriptingEnvironment::GetStringListFromTable("restrictions"); case 1: @@ -785,6 +924,7 @@ void Sol2ScriptingEnvironment::ProcessTurn(ExtractionTurn &turn) switch (context.api_version) { + case 3: case 2: if (context.has_turn_penalty_function) { @@ -851,6 +991,7 @@ void Sol2ScriptingEnvironment::ProcessSegment(ExtractionSegment &segment) { switch (context.api_version) { + case 3: case 2: context.segment_function(context.profile_table, segment); break; @@ -872,6 +1013,9 @@ void LuaScriptingContext::ProcessNode(const osmium::Node &node, ExtractionNode & switch (api_version) { + case 3: +// BOOST_ASSERT(false); // TODO: implement me +// break; case 2: node_function(profile_table, node, result); break; @@ -888,6 +1032,9 @@ void LuaScriptingContext::ProcessWay(const osmium::Way &way, ExtractionWay &resu switch (api_version) { + case 3: +// BOOST_ASSERT(false); // TODO: implement me +// break; case 2: way_function(profile_table, way, result); break; @@ -897,5 +1044,14 @@ void LuaScriptingContext::ProcessWay(const osmium::Way &way, ExtractionWay &resu break; } } + +void LuaScriptingContext::ProcessRelation(const osmium::Relation &relation, ExtractionRelation &result) +{ + BOOST_ASSERT(state.lua_state() != nullptr); + BOOST_ASSERT(api_version > 2); + + relation_function(profile_table, relation, result); } -} + +} // namespace extractor +} // namespace osrm diff --git a/unit_tests/mocks/mock_scripting_environment.hpp b/unit_tests/mocks/mock_scripting_environment.hpp index f6a7d80fe..3f53e1521 100644 --- a/unit_tests/mocks/mock_scripting_environment.hpp +++ b/unit_tests/mocks/mock_scripting_environment.hpp @@ -38,6 +38,7 @@ class MockScriptingEnvironment : public extractor::ScriptingEnvironment const extractor::RestrictionParser &, std::vector> &, std::vector> &, + std::vector> &resulting_relations, std::vector &) override final { }