Improved layer name -> type mappings.

This commit is contained in:
2018-03-20 14:22:15 +01:00
parent a4e58e22bc
commit 229b7c8004
3 changed files with 40 additions and 16 deletions

View File

@@ -3,24 +3,24 @@
using namespace std; using namespace std;
using namespace fmri; using namespace fmri;
const unordered_map<string_view, LayerInfo::Type> LayerInfo::NAME_TYPE_MAP = {
{"Input", Type::Input},
{"Convolution", Type::Convolutional},
{"ReLU", Type::ReLU},
{"Pooling", Type::Pooling},
{"InnerProduct", Type::InnerProduct},
{"DropOut", Type::DropOut},
{"LRN", Type::LRN},
{"Split", Type::Split},
{"Softmax", Type::Softmax}
};
LayerInfo::Type LayerInfo::typeByName(string_view name) LayerInfo::Type LayerInfo::typeByName(string_view name)
{ {
if (name == "Input") { try {
return Type::Input; return NAME_TYPE_MAP.at(name);
} else if (name == "Convolution") { } catch (std::out_of_range &e) {
return Type::Convolutional;
} else if (name == "ReLU") {
return Type::ReLU;
} else if (name == "Pooling") {
return Type::Pooling;
} else if (name == "InnerProduct") {
return Type::InnerProduct;
} else if (name == "Dropout") {
return Type::DropOut;
} else if (name == "LRN") {
return Type::LRN;
} else {
LOG(INFO) << "Received unknown layer type: " << name << endl; LOG(INFO) << "Received unknown layer type: " << name << endl;
return Type::Other; return Type::Other;
} }
@@ -46,3 +46,17 @@ const std::vector<boost::shared_ptr<caffe::Blob<DType>>>& LayerInfo::parameters(
{ {
return parameters_; return parameters_;
} }
std::ostream &fmri::operator<<(std::ostream &out, LayerInfo::Type type)
{
for (auto i : LayerInfo::NAME_TYPE_MAP) {
if (i.second == type) {
out << i.first;
return out;
}
}
out << "ERROR! UNSUPPORTED TYPE";
return out;
}

View File

@@ -19,6 +19,8 @@ namespace fmri
InnerProduct, InnerProduct,
DropOut, DropOut,
LRN, LRN,
Split,
Softmax,
Other Other
}; };
@@ -31,9 +33,15 @@ namespace fmri
static Type typeByName(std::string_view name); static Type typeByName(std::string_view name);
friend std::ostream& operator<<(std::ostream& out, Type type);
private: private:
std::vector<boost::shared_ptr<caffe::Blob<DType>>> parameters_; std::vector<boost::shared_ptr<caffe::Blob<DType>>> parameters_;
Type type_; Type type_;
std::string name_; std::string name_;
const static std::unordered_map<std::string_view, Type> NAME_TYPE_MAP;
}; };
std::ostream& operator<<(std::ostream& out, LayerInfo::Type type);
} }

View File

@@ -205,7 +205,9 @@ void Simulator::Impl::computeLayerInfo()
for (auto i : Range(names.size())) { for (auto i : Range(names.size())) {
auto& layer = layers[i]; auto& layer = layers[i];
layerInfo_.emplace(names[i], LayerInfo(names[i], layer->type(), layer->blobs())); LayerInfo layerInfo(names[i], layer->type(), layer->blobs());
CHECK_NE(layerInfo.type(), LayerInfo::Type::Split) << "Split layers are not supported!";
layerInfo_.emplace(names[i], std::move(layerInfo));
} }
} }