From 229b7c8004a3a397d8033f2f9a8b8be0ad6e9f42 Mon Sep 17 00:00:00 2001 From: Bert Peters Date: Tue, 20 Mar 2018 14:22:15 +0100 Subject: [PATCH] Improved layer name -> type mappings. --- src/LayerInfo.cpp | 44 +++++++++++++++++++++++++++++--------------- src/LayerInfo.hpp | 8 ++++++++ src/Simulator.cpp | 4 +++- 3 files changed, 40 insertions(+), 16 deletions(-) diff --git a/src/LayerInfo.cpp b/src/LayerInfo.cpp index a5686e4..8741053 100644 --- a/src/LayerInfo.cpp +++ b/src/LayerInfo.cpp @@ -3,24 +3,24 @@ using namespace std; using namespace fmri; +const unordered_map LayerInfo::NAME_TYPE_MAP = { + {"Input", Type::Input}, + {"Convolution", Type::Convolutional}, + {"ReLU", Type::ReLU}, + {"Pooling", Type::Pooling}, + {"InnerProduct", Type::InnerProduct}, + {"DropOut", Type::DropOut}, + {"LRN", Type::LRN}, + {"Split", Type::Split}, + {"Softmax", Type::Softmax} +}; + LayerInfo::Type LayerInfo::typeByName(string_view name) { - if (name == "Input") { - return Type::Input; - } else if (name == "Convolution") { - return Type::Convolutional; - } else if (name == "ReLU") { - return Type::ReLU; - } else if (name == "Pooling") { - return Type::Pooling; - } else if (name == "InnerProduct") { - return Type::InnerProduct; - } else if (name == "Dropout") { - return Type::DropOut; - } else if (name == "LRN") { - return Type::LRN; - } else { + try { + return NAME_TYPE_MAP.at(name); + } catch (std::out_of_range &e) { LOG(INFO) << "Received unknown layer type: " << name << endl; return Type::Other; } @@ -46,3 +46,17 @@ const std::vector>>& LayerInfo::parameters( { return parameters_; } + +std::ostream &fmri::operator<<(std::ostream &out, LayerInfo::Type type) +{ + for (auto i : LayerInfo::NAME_TYPE_MAP) { + if (i.second == type) { + out << i.first; + return out; + } + } + + out << "ERROR! UNSUPPORTED TYPE"; + return out; +} + diff --git a/src/LayerInfo.hpp b/src/LayerInfo.hpp index 60050a5..7a065f5 100644 --- a/src/LayerInfo.hpp +++ b/src/LayerInfo.hpp @@ -19,6 +19,8 @@ namespace fmri InnerProduct, DropOut, LRN, + Split, + Softmax, Other }; @@ -31,9 +33,15 @@ namespace fmri static Type typeByName(std::string_view name); + friend std::ostream& operator<<(std::ostream& out, Type type); + private: std::vector>> parameters_; Type type_; std::string name_; + + const static std::unordered_map NAME_TYPE_MAP; }; + + std::ostream& operator<<(std::ostream& out, LayerInfo::Type type); } diff --git a/src/Simulator.cpp b/src/Simulator.cpp index 21cd0bb..86e2735 100644 --- a/src/Simulator.cpp +++ b/src/Simulator.cpp @@ -205,7 +205,9 @@ void Simulator::Impl::computeLayerInfo() for (auto i : Range(names.size())) { auto& layer = layers[i]; - layerInfo_.emplace(names[i], LayerInfo(names[i], layer->type(), layer->blobs())); + LayerInfo layerInfo(names[i], layer->type(), layer->blobs()); + CHECK_NE(layerInfo.type(), LayerInfo::Type::Split) << "Split layers are not supported!"; + layerInfo_.emplace(names[i], std::move(layerInfo)); } }