Improved layer name -> type mappings.

This commit is contained in:
2018-03-20 14:22:15 +01:00
parent a4e58e22bc
commit 229b7c8004
3 changed files with 40 additions and 16 deletions

View File

@@ -3,24 +3,24 @@
using namespace std;
using namespace fmri;
const unordered_map<string_view, LayerInfo::Type> LayerInfo::NAME_TYPE_MAP = {
{"Input", Type::Input},
{"Convolution", Type::Convolutional},
{"ReLU", Type::ReLU},
{"Pooling", Type::Pooling},
{"InnerProduct", Type::InnerProduct},
{"DropOut", Type::DropOut},
{"LRN", Type::LRN},
{"Split", Type::Split},
{"Softmax", Type::Softmax}
};
LayerInfo::Type LayerInfo::typeByName(string_view name)
{
if (name == "Input") {
return Type::Input;
} else if (name == "Convolution") {
return Type::Convolutional;
} else if (name == "ReLU") {
return Type::ReLU;
} else if (name == "Pooling") {
return Type::Pooling;
} else if (name == "InnerProduct") {
return Type::InnerProduct;
} else if (name == "Dropout") {
return Type::DropOut;
} else if (name == "LRN") {
return Type::LRN;
} else {
try {
return NAME_TYPE_MAP.at(name);
} catch (std::out_of_range &e) {
LOG(INFO) << "Received unknown layer type: " << name << endl;
return Type::Other;
}
@@ -46,3 +46,17 @@ const std::vector<boost::shared_ptr<caffe::Blob<DType>>>& LayerInfo::parameters(
{
return parameters_;
}
std::ostream &fmri::operator<<(std::ostream &out, LayerInfo::Type type)
{
for (auto i : LayerInfo::NAME_TYPE_MAP) {
if (i.second == type) {
out << i.first;
return out;
}
}
out << "ERROR! UNSUPPORTED TYPE";
return out;
}

View File

@@ -19,6 +19,8 @@ namespace fmri
InnerProduct,
DropOut,
LRN,
Split,
Softmax,
Other
};
@@ -31,9 +33,15 @@ namespace fmri
static Type typeByName(std::string_view name);
friend std::ostream& operator<<(std::ostream& out, Type type);
private:
std::vector<boost::shared_ptr<caffe::Blob<DType>>> parameters_;
Type type_;
std::string name_;
const static std::unordered_map<std::string_view, Type> NAME_TYPE_MAP;
};
std::ostream& operator<<(std::ostream& out, LayerInfo::Type type);
}

View File

@@ -205,7 +205,9 @@ void Simulator::Impl::computeLayerInfo()
for (auto i : Range(names.size())) {
auto& layer = layers[i];
layerInfo_.emplace(names[i], LayerInfo(names[i], layer->type(), layer->blobs()));
LayerInfo layerInfo(names[i], layer->type(), layer->blobs());
CHECK_NE(layerInfo.type(), LayerInfo::Type::Split) << "Split layers are not supported!";
layerInfo_.emplace(names[i], std::move(layerInfo));
}
}