diff --git a/include/extractor/class_data.hpp b/include/extractor/class_data.hpp index 717a73600..fefee6bc6 100644 --- a/include/extractor/class_data.hpp +++ b/include/extractor/class_data.hpp @@ -12,6 +12,7 @@ namespace extractor using ClassData = std::uint8_t; static const std::uint8_t MAX_CLASS_INDEX = 8 - 1; +static const std::uint8_t MAX_AVOIDABLE_CLASSES = 8; inline bool isSubset(const ClassData lhs, const ClassData rhs) { return (lhs & rhs) == lhs; } diff --git a/include/extractor/profile_properties.hpp b/include/extractor/profile_properties.hpp index 98a7cf3aa..a29cff939 100644 --- a/include/extractor/profile_properties.hpp +++ b/include/extractor/profile_properties.hpp @@ -5,9 +5,11 @@ #include "util/typedefs.hpp" -#include #include +#include #include + +#include #include namespace osrm @@ -70,6 +72,24 @@ struct ProfileProperties return std::string(weight_name); } + // Mark this combination of classes as avoidable + void SetAvoidableClasses(std::size_t index, ClassData classes) + { + avoidable_classes[index] = classes; + } + + // Check if this classes are avoidable + boost::optional ClassesAreAvoidable(ClassData classes) + { + auto iter = std::find(avoidable_classes.begin(), avoidable_classes.end(), classes); + if (iter != avoidable_classes.end()) + { + return std::distance(avoidable_classes.begin(), iter); + } + + return {}; + } + void SetClassName(std::size_t index, const std::string &name) { char *name_ptr = class_names[index]; @@ -109,6 +129,8 @@ struct ProfileProperties char weight_name[MAX_WEIGHT_NAME_LENGTH + 1]; //! stores the names of each class std::array class_names; + //! stores the masks of avoidable class combinations + std::array avoidable_classes; unsigned weight_precision = 1; bool force_split_edges = false; bool call_tagless_node_function = true; diff --git a/include/extractor/scripting_environment.hpp b/include/extractor/scripting_environment.hpp index 01d31c482..371a86adf 100644 --- a/include/extractor/scripting_environment.hpp +++ b/include/extractor/scripting_environment.hpp @@ -52,6 +52,7 @@ class ScriptingEnvironment virtual const ProfileProperties &GetProfileProperties() = 0; + virtual std::vector> GetAvoidableClasses() = 0; virtual std::vector GetNameSuffixList() = 0; virtual std::vector GetRestrictions() = 0; virtual void ProcessTurn(ExtractionTurn &turn) = 0; diff --git a/include/extractor/scripting_environment_lua.hpp b/include/extractor/scripting_environment_lua.hpp index 3684f0857..8bdc83551 100644 --- a/include/extractor/scripting_environment_lua.hpp +++ b/include/extractor/scripting_environment_lua.hpp @@ -58,10 +58,7 @@ class Sol2ScriptingEnvironment final : public ScriptingEnvironment const ProfileProperties &GetProfileProperties() override; - LuaScriptingContext &GetSol2Context(); - - std::vector GetStringListFromTable(const std::string &table_name); - std::vector GetStringListFromFunction(const std::string &function_name); + std::vector> GetAvoidableClasses() override; std::vector GetNameSuffixList() override; std::vector GetRestrictions() override; void ProcessTurn(ExtractionTurn &turn) override; @@ -75,6 +72,12 @@ class Sol2ScriptingEnvironment final : public ScriptingEnvironment std::vector &resulting_restrictions) override; private: + LuaScriptingContext &GetSol2Context(); + + std::vector GetStringListFromTable(const std::string &table_name); + std::vector> GetStringListsFromTable(const std::string &table_name); + std::vector GetStringListFromFunction(const std::string &function_name); + void InitContext(LuaScriptingContext &context); std::mutex init_mutex; std::string file_name; diff --git a/profiles/car.lua b/profiles/car.lua index d68dde487..64dc33a04 100644 --- a/profiles/car.lua +++ b/profiles/car.lua @@ -100,6 +100,13 @@ function setup() 'vehicle' }, + -- classes to support for avoid flags + avoidable = Sequence { + Set {"toll"}, + Set {"motorway"}, + Set {"ferry"} + }, + avoid = Set { 'area', -- 'toll', -- uncomment this to avoid tolls diff --git a/src/extractor/extractor.cpp b/src/extractor/extractor.cpp index a540fe46d..74dd74685 100644 --- a/src/extractor/extractor.cpp +++ b/src/extractor/extractor.cpp @@ -78,6 +78,41 @@ void SetClassNames(const ExtractorCallbacks::ClassesMap &classes_map, profile_properties.SetClassName(range.front(), pair.first); } } + +// Converts the class name list to a mask list +void SetAvoidableClasses(const ExtractorCallbacks::ClassesMap &classes_map, + const std::vector> &avoidable_classes, + ProfileProperties &profile_properties) +{ + if (avoidable_classes.size() > MAX_AVOIDABLE_CLASSES) + { + throw util::exception("Only " + std::to_string(MAX_AVOIDABLE_CLASSES) + " avoidable combinations allowed."); + } + + std::size_t combination_index = 0; + for (const auto &combination : avoidable_classes) + { + ClassData mask = 0; + for (const auto &name : combination) + { + auto iter = classes_map.find(name); + if (iter == classes_map.end()) + { + util::Log(logWARNING) + << "Unknown class name " + name + " in avoidable combination. Ignoring."; + } + else + { + mask |= iter->second; + } + } + + if (mask > 0) + { + profile_properties.SetAvoidableClasses(combination_index++, mask); + } + } +} } /** @@ -341,6 +376,8 @@ Extractor::ParseOSMData(ScriptingEnvironment &scripting_environment, auto profile_properties = scripting_environment.GetProfileProperties(); SetClassNames(classes_map, profile_properties); + auto avoidable_classes = scripting_environment.GetAvoidableClasses(); + SetAvoidableClasses(classes_map, avoidable_classes, profile_properties); files::writeProfileProperties(config.GetPath(".osrm.properties").string(), profile_properties); TIMER_STOP(extracting); diff --git a/src/extractor/scripting_environment_lua.cpp b/src/extractor/scripting_environment_lua.cpp index e63ba0c39..8b71af05c 100644 --- a/src/extractor/scripting_environment_lua.cpp +++ b/src/extractor/scripting_environment_lua.cpp @@ -689,11 +689,55 @@ Sol2ScriptingEnvironment::GetStringListFromTable(const std::string &table_name) for (auto &&pair : table) { strings.push_back(pair.second.as()); - }; + } } return strings; } +std::vector> +Sol2ScriptingEnvironment::GetStringListsFromTable(const std::string &table_name) +{ + std::vector> string_lists; + + auto &context = GetSol2Context(); + BOOST_ASSERT(context.state.lua_state() != nullptr); + sol::table table = context.profile_table[table_name]; + if (!table.valid()) + { + return string_lists; + } + + for (const auto &pair : table) + { + sol::table inner_table = pair.second; + if (!inner_table.valid()) + { + throw util::exception("Expected a sub-table at " + table_name + "[" + pair.first.as() + "]"); + } + + std::vector inner_vector; + for (const auto &inner_pair : inner_table) + { + inner_vector.push_back(inner_pair.first.as()); + } + string_lists.push_back(std::move(inner_vector)); + } + + return string_lists; +} + +std::vector> Sol2ScriptingEnvironment::GetAvoidableClasses() +{ + auto &context = GetSol2Context(); + switch (context.api_version) + { + case 2: + return Sol2ScriptingEnvironment::GetStringListsFromTable("avoidable"); + default: + return {}; + } +} + std::vector Sol2ScriptingEnvironment::GetNameSuffixList() { auto &context = GetSol2Context();