Port extractor to TBB

This commit is contained in:
Patrick Niklaus
2014-05-16 13:26:42 +02:00
parent d1fdc7061f
commit da1fd96d4e
6 changed files with 395 additions and 96 deletions
+1 -1
View File
@@ -42,7 +42,7 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
BaseParser::BaseParser(ExtractorCallbacks *extractor_callbacks,
ScriptingEnvironment &scripting_environment)
: extractor_callbacks(extractor_callbacks),
lua_state(scripting_environment.getLuaStateForThreadID(0)),
lua_state(scripting_environment.getLuaState()),
scripting_environment(scripting_environment), use_turn_restrictions(true)
{
ReadUseRestrictionsSetting();
+25 -16
View File
@@ -40,6 +40,8 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#include "../Util/SimpleLogger.h"
#include "../typedefs.h"
#include <tbb/parallel_for.h>
#include <osrm/Coordinate.h>
#include <zlib.h>
@@ -257,16 +259,18 @@ inline void PBFParser::parseDenseNode(ParserThreadData *thread_data)
denseTagIndex += 2;
}
}
#pragma omp parallel
{
const int thread_num = omp_get_thread_num();
#pragma omp parallel for schedule(guided)
for (int i = 0; i < number_of_nodes; ++i)
tbb::parallel_for(tbb::blocked_range<size_t>(0, extracted_nodes_vector.size()),
[this, &extracted_nodes_vector](const tbb::blocked_range<size_t>& range)
{
ImportNode &import_node = extracted_nodes_vector[i];
ParseNodeInLua(import_node, scripting_environment.getLuaStateForThreadID(thread_num));
lua_State* lua_state = this->scripting_environment.getLuaState();
for (size_t i = range.begin(); i != range.end(); i++)
{
ImportNode &import_node = extracted_nodes_vector[i];
ParseNodeInLua(import_node, lua_state);
}
}
}
);
for (const ImportNode &import_node : extracted_nodes_vector)
{
@@ -424,16 +428,21 @@ inline void PBFParser::parseWay(ParserThreadData *thread_data)
}
}
#pragma omp parallel for schedule(guided)
for (int i = 0; i < number_of_ways; ++i)
{
ExtractionWay &extraction_way = parsed_way_vector[i];
if (2 <= extraction_way.path.size())
// TODO: investigate if schedule guided will be handled by tbb automatically
tbb::parallel_for(tbb::blocked_range<size_t>(0, parsed_way_vector.size()),
[this, &parsed_way_vector](const tbb::blocked_range<size_t>& range)
{
ParseWayInLua(extraction_way,
scripting_environment.getLuaStateForThreadID(omp_get_thread_num()));
lua_State* lua_state = this->scripting_environment.getLuaState();
for (size_t i = range.begin(); i != range.end(); i++)
{
ExtractionWay &extraction_way = parsed_way_vector[i];
if (2 <= extraction_way.path.size())
{
ParseWayInLua(extraction_way, lua_state);
}
}
}
}
);
for (ExtractionWay &extraction_way : parsed_way_vector)
{
+73 -75
View File
@@ -38,87 +38,85 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
ScriptingEnvironment::ScriptingEnvironment() {}
ScriptingEnvironment::ScriptingEnvironment(const char *file_name)
: file_name(file_name)
{
SimpleLogger().Write() << "Using script " << file_name;
// Create a new lua state
for (int i = 0; i < omp_get_max_threads(); ++i)
{
lua_state_vector.push_back(luaL_newstate());
}
// Connect LuaBind to this lua state for all threads
#pragma omp parallel
{
lua_State *lua_state = getLuaStateForThreadID(omp_get_thread_num());
luabind::open(lua_state);
// open utility libraries string library;
luaL_openlibs(lua_state);
luaAddScriptFolderToLoadPath(lua_state, file_name);
// Add our function to the state's global scope
luabind::module(lua_state)[
luabind::def("print", LUA_print<std::string>),
luabind::def("durationIsValid", durationIsValid),
luabind::def("parseDuration", parseDuration)
];
luabind::module(lua_state)[luabind::class_<HashTable<std::string, std::string>>("keyVals")
.def("Add", &HashTable<std::string, std::string>::Add)
.def("Find", &HashTable<std::string, std::string>::Find)
.def("Holds", &HashTable<std::string, std::string>::Holds)];
luabind::module(lua_state)[luabind::class_<ImportNode>("Node")
.def(luabind::constructor<>())
.def_readwrite("lat", &ImportNode::lat)
.def_readwrite("lon", &ImportNode::lon)
.def_readonly("id", &ImportNode::id)
.def_readwrite("bollard", &ImportNode::bollard)
.def_readwrite("traffic_light", &ImportNode::trafficLight)
.def_readwrite("tags", &ImportNode::keyVals)];
luabind::module(lua_state)
[luabind::class_<ExtractionWay>("Way")
.def(luabind::constructor<>())
.def_readonly("id", &ExtractionWay::id)
.def_readwrite("name", &ExtractionWay::name)
.def_readwrite("speed", &ExtractionWay::speed)
.def_readwrite("backward_speed", &ExtractionWay::backward_speed)
.def_readwrite("duration", &ExtractionWay::duration)
.def_readwrite("type", &ExtractionWay::type)
.def_readwrite("access", &ExtractionWay::access)
.def_readwrite("roundabout", &ExtractionWay::roundabout)
.def_readwrite("is_access_restricted", &ExtractionWay::isAccessRestricted)
.def_readwrite("ignore_in_grid", &ExtractionWay::ignoreInGrid)
.def_readwrite("tags", &ExtractionWay::keyVals)
.def_readwrite("direction", &ExtractionWay::direction)
.enum_("constants")[
luabind::value("notSure", 0),
luabind::value("oneway", 1),
luabind::value("bidirectional", 2),
luabind::value("opposite", 3)
]];
// fails on c++11/OS X 10.9
luabind::module(lua_state)[luabind::class_<std::vector<std::string>>("vector").def(
"Add",
static_cast<void (std::vector<std::string>::*)(const std::string &)>(
&std::vector<std::string>::push_back))];
if (0 != luaL_dofile(lua_state, file_name))
{
throw OSRMException("ERROR occured in scripting block");
}
}
}
ScriptingEnvironment::~ScriptingEnvironment()
void ScriptingEnvironment::initLuaState(lua_State* lua_state)
{
for (unsigned i = 0; i < lua_state_vector.size(); ++i)
luabind::open(lua_state);
// open utility libraries string library;
luaL_openlibs(lua_state);
luaAddScriptFolderToLoadPath(lua_state, file_name.c_str());
// Add our function to the state's global scope
luabind::module(lua_state)[
luabind::def("print", LUA_print<std::string>),
luabind::def("durationIsValid", durationIsValid),
luabind::def("parseDuration", parseDuration)
];
luabind::module(lua_state)[luabind::class_<HashTable<std::string, std::string>>("keyVals")
.def("Add", &HashTable<std::string, std::string>::Add)
.def("Find", &HashTable<std::string, std::string>::Find)
.def("Holds", &HashTable<std::string, std::string>::Holds)];
luabind::module(lua_state)[luabind::class_<ImportNode>("Node")
.def(luabind::constructor<>())
.def_readwrite("lat", &ImportNode::lat)
.def_readwrite("lon", &ImportNode::lon)
.def_readonly("id", &ImportNode::id)
.def_readwrite("bollard", &ImportNode::bollard)
.def_readwrite("traffic_light", &ImportNode::trafficLight)
.def_readwrite("tags", &ImportNode::keyVals)];
luabind::module(lua_state)
[luabind::class_<ExtractionWay>("Way")
.def(luabind::constructor<>())
.def_readonly("id", &ExtractionWay::id)
.def_readwrite("name", &ExtractionWay::name)
.def_readwrite("speed", &ExtractionWay::speed)
.def_readwrite("backward_speed", &ExtractionWay::backward_speed)
.def_readwrite("duration", &ExtractionWay::duration)
.def_readwrite("type", &ExtractionWay::type)
.def_readwrite("access", &ExtractionWay::access)
.def_readwrite("roundabout", &ExtractionWay::roundabout)
.def_readwrite("is_access_restricted", &ExtractionWay::isAccessRestricted)
.def_readwrite("ignore_in_grid", &ExtractionWay::ignoreInGrid)
.def_readwrite("tags", &ExtractionWay::keyVals)
.def_readwrite("direction", &ExtractionWay::direction)
.enum_("constants")[
luabind::value("notSure", 0),
luabind::value("oneway", 1),
luabind::value("bidirectional", 2),
luabind::value("opposite", 3)
]];
// fails on c++11/OS X 10.9
luabind::module(lua_state)[luabind::class_<std::vector<std::string>>("vector").def(
"Add",
static_cast<void (std::vector<std::string>::*)(const std::string &)>(
&std::vector<std::string>::push_back))];
if (0 != luaL_dofile(lua_state, file_name.c_str()))
{
// lua_state_vector[i];
throw OSRMException("ERROR occured in scripting block");
}
}
lua_State *ScriptingEnvironment::getLuaStateForThreadID(const int id) { return lua_state_vector[id]; }
lua_State *ScriptingEnvironment::getLuaState()
{
bool initialized = false;
auto& ref = script_contexts.local(initialized);
if (!initialized)
{
std::shared_ptr<lua_State> state(luaL_newstate(), lua_close);
ref = state;
initLuaState(ref.get());
}
return ref.get();
}
+9 -4
View File
@@ -28,7 +28,9 @@ SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#ifndef SCRIPTINGENVIRONMENT_H_
#define SCRIPTINGENVIRONMENT_H_
#include <vector>
#include <string>
#include <memory>
#include <tbb/enumerable_thread_specific.h>
struct lua_State;
@@ -37,11 +39,14 @@ class ScriptingEnvironment
public:
ScriptingEnvironment();
explicit ScriptingEnvironment(const char *file_name);
virtual ~ScriptingEnvironment();
lua_State *getLuaStateForThreadID(const int);
lua_State *getLuaState();
std::vector<lua_State *> lua_state_vector;
private:
void initLuaState(lua_State* lua_state);
std::string file_name;
tbb::enumerable_thread_specific<std::shared_ptr<lua_State>> script_contexts;
};
#endif /* SCRIPTINGENVIRONMENT_H_ */