Improved layer name -> type mappings.
This commit is contained in:
@@ -3,24 +3,24 @@
|
||||
using namespace std;
|
||||
using namespace fmri;
|
||||
|
||||
const unordered_map<string_view, LayerInfo::Type> LayerInfo::NAME_TYPE_MAP = {
|
||||
{"Input", Type::Input},
|
||||
{"Convolution", Type::Convolutional},
|
||||
{"ReLU", Type::ReLU},
|
||||
{"Pooling", Type::Pooling},
|
||||
{"InnerProduct", Type::InnerProduct},
|
||||
{"DropOut", Type::DropOut},
|
||||
{"LRN", Type::LRN},
|
||||
{"Split", Type::Split},
|
||||
{"Softmax", Type::Softmax}
|
||||
};
|
||||
|
||||
|
||||
LayerInfo::Type LayerInfo::typeByName(string_view name)
|
||||
{
|
||||
if (name == "Input") {
|
||||
return Type::Input;
|
||||
} else if (name == "Convolution") {
|
||||
return Type::Convolutional;
|
||||
} else if (name == "ReLU") {
|
||||
return Type::ReLU;
|
||||
} else if (name == "Pooling") {
|
||||
return Type::Pooling;
|
||||
} else if (name == "InnerProduct") {
|
||||
return Type::InnerProduct;
|
||||
} else if (name == "Dropout") {
|
||||
return Type::DropOut;
|
||||
} else if (name == "LRN") {
|
||||
return Type::LRN;
|
||||
} else {
|
||||
try {
|
||||
return NAME_TYPE_MAP.at(name);
|
||||
} catch (std::out_of_range &e) {
|
||||
LOG(INFO) << "Received unknown layer type: " << name << endl;
|
||||
return Type::Other;
|
||||
}
|
||||
@@ -46,3 +46,17 @@ const std::vector<boost::shared_ptr<caffe::Blob<DType>>>& LayerInfo::parameters(
|
||||
{
|
||||
return parameters_;
|
||||
}
|
||||
|
||||
std::ostream &fmri::operator<<(std::ostream &out, LayerInfo::Type type)
|
||||
{
|
||||
for (auto i : LayerInfo::NAME_TYPE_MAP) {
|
||||
if (i.second == type) {
|
||||
out << i.first;
|
||||
return out;
|
||||
}
|
||||
}
|
||||
|
||||
out << "ERROR! UNSUPPORTED TYPE";
|
||||
return out;
|
||||
}
|
||||
|
||||
|
||||
@@ -19,6 +19,8 @@ namespace fmri
|
||||
InnerProduct,
|
||||
DropOut,
|
||||
LRN,
|
||||
Split,
|
||||
Softmax,
|
||||
Other
|
||||
};
|
||||
|
||||
@@ -31,9 +33,15 @@ namespace fmri
|
||||
|
||||
static Type typeByName(std::string_view name);
|
||||
|
||||
friend std::ostream& operator<<(std::ostream& out, Type type);
|
||||
|
||||
private:
|
||||
std::vector<boost::shared_ptr<caffe::Blob<DType>>> parameters_;
|
||||
Type type_;
|
||||
std::string name_;
|
||||
|
||||
const static std::unordered_map<std::string_view, Type> NAME_TYPE_MAP;
|
||||
};
|
||||
|
||||
std::ostream& operator<<(std::ostream& out, LayerInfo::Type type);
|
||||
}
|
||||
|
||||
@@ -205,7 +205,9 @@ void Simulator::Impl::computeLayerInfo()
|
||||
|
||||
for (auto i : Range(names.size())) {
|
||||
auto& layer = layers[i];
|
||||
layerInfo_.emplace(names[i], LayerInfo(names[i], layer->type(), layer->blobs()));
|
||||
LayerInfo layerInfo(names[i], layer->type(), layer->blobs());
|
||||
CHECK_NE(layerInfo.type(), LayerInfo::Type::Split) << "Split layers are not supported!";
|
||||
layerInfo_.emplace(names[i], std::move(layerInfo));
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user