Implement visualisation of dropout layers.

This commit is contained in:
2018-02-19 16:51:18 +01:00
parent 6d64830a46
commit 219e8cbec9
3 changed files with 26 additions and 4 deletions

View File

@@ -16,6 +16,8 @@ LayerInfo::Type LayerInfo::typeByName(string_view name)
return Type::Pooling; return Type::Pooling;
} else if (name == "InnerProduct") { } else if (name == "InnerProduct") {
return Type::InnerProduct; return Type::InnerProduct;
} else if (name == "Dropout") {
return Type::DropOut;
} else { } else {
LOG(INFO) << "Received unknown layer type: " << name << endl; LOG(INFO) << "Received unknown layer type: " << name << endl;
return Type::Other; return Type::Other;

View File

@@ -18,6 +18,7 @@ namespace fmri
Pooling, Pooling,
Output, Output,
InnerProduct, InnerProduct,
DropOut,
Other Other
}; };

View File

@@ -13,6 +13,8 @@ using namespace std;
// Maximum number of interactions shown // Maximum number of interactions shown
static constexpr size_t INTERACTION_LIMIT = 10000; static constexpr size_t INTERACTION_LIMIT = 10000;
typedef vector<pair<float, pair<size_t, size_t>>> EntryList;
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer) fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer)
{ {
switch (layer.shape().size()) { switch (layer.shape().size()) {
@@ -32,8 +34,6 @@ static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, c
{ {
LOG(INFO) << "Computing top interactions for " << layer.name() << endl; LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
typedef pair<DType, pair<size_t, size_t>> Entry;
auto data = prevState.data(); auto data = prevState.data();
CHECK_GE(layer.parameters().size(), 1) << "Layer should have correct parameters"; CHECK_GE(layer.parameters().size(), 1) << "Layer should have correct parameters";
@@ -53,7 +53,7 @@ static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, c
return abs(a) > abs(b); return abs(a) > abs(b);
}); });
vector<Entry> result; EntryList result;
result.reserve(desiredSize); result.reserve(desiredSize);
for (auto i : idx) { for (auto i : idx) {
result.emplace_back(interactions[i], make_pair(i / shape[0], i % shape[0])); result.emplace_back(interactions[i], make_pair(i / shape[0], i % shape[0]));
@@ -62,12 +62,27 @@ static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, c
return new ActivityAnimation(result, prevPositions.data(), curPositions.data(), -10); return new ActivityAnimation(result, prevPositions.data(), curPositions.data(), -10);
} }
static Animation *getDropOutAnimation(const fmri::LayerData &curState,
const vector<float> &prevPositions,
const vector<float> &curPositions) {
auto data = curState.data();
EntryList results;
results.reserve(curState.numEntries());
for (auto i : Range(curState.numEntries())) {
if (data[i] != 0) {
results.emplace_back(data[i], make_pair(i, i));
}
}
return new ActivityAnimation(results, prevPositions.data(), curPositions.data(), -10);
}
Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState, Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState,
const fmri::LayerInfo &layer, const vector<float> &prevPositions, const fmri::LayerInfo &layer, const vector<float> &prevPositions,
const vector<float> &curPositions) const vector<float> &curPositions)
{ {
if (prevPositions.empty() || curPositions.empty()) { if (prevPositions.empty() || curPositions.empty()) {
// Not all positions know, no visualisation possible. // Not all positions known, no visualisation possible.
return nullptr; return nullptr;
} }
@@ -77,6 +92,10 @@ Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const f
return getFullyConnectedAnimation(prevState, layer, return getFullyConnectedAnimation(prevState, layer,
prevPositions, curPositions); prevPositions, curPositions);
case LayerInfo::Type::DropOut:
return getDropOutAnimation(curState, prevPositions, curPositions);
default: default:
return nullptr; return nullptr;
} }