Implement some visualisation for ReLU layers.

This commit is contained in:
2018-02-22 16:20:59 +01:00
parent 3c5358d8d6
commit 86d2b8b48a
2 changed files with 49 additions and 2 deletions

View File

@@ -16,7 +16,6 @@ namespace fmri
Convolutional,
ReLU,
Pooling,
Output,
InnerProduct,
DropOut,
Other

View File

@@ -1,5 +1,5 @@
#include <algorithm>
#include <numeric>
#include <utility>
#include "visualisations.hpp"
#include "DummyLayerVisualisation.hpp"
#include "MultiImageVisualisation.hpp"
@@ -41,6 +41,29 @@ static inline int getNodeNormalizer(const LayerData& layer) {
}
}
/**
* Deduplicate interaction entries.
*
* For duplicate interactions, the interaction strengths are summed.
*
* @param entries
* @return the deduplicated entries.
*/
static EntryList deduplicate(const EntryList& entries)
{
map<pair<size_t, size_t>, float> combiner;
for (auto entry : entries) {
combiner[entry.second] += entry.first;
}
EntryList result;
transform(combiner.begin(), combiner.end(), back_inserter(result), [](const auto& item) {
return make_pair(item.second, item.first);
});
return result;
}
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer)
{
switch (layer.shape().size()) {
@@ -105,6 +128,29 @@ static Animation *getDropOutAnimation(const fmri::LayerData &prevState,
}
}
results = deduplicate(results);
return new ActivityAnimation(results, prevPositions.data(), curPositions.data(), -10);
}
static Animation *getReLUAnimation(const fmri::LayerData &prevState,
const fmri::LayerData &curState,
const vector<float> &prevPositions,
const vector<float> &curPositions) {
CHECK_EQ(curState.numEntries(), prevState.numEntries()) << "Layers should be of same size!";
const auto prevData = prevState.data(), curData = curState.data();
const auto sourceNormalize = getNodeNormalizer(prevState);
const auto sinkNormalize = getNodeNormalizer(curState);
EntryList results;
for (auto i : Range(curState.numEntries())) {
results.emplace_back(curData[i] - prevData[i], make_pair(i / sourceNormalize, i / sinkNormalize));
}
results = deduplicate(results);
return new ActivityAnimation(results, prevPositions.data(), curPositions.data(), -10);
}
@@ -126,6 +172,8 @@ Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const f
case LayerInfo::Type::DropOut:
return getDropOutAnimation(prevState, curState, prevPositions, curPositions);
case LayerInfo::Type::ReLU:
return getReLUAnimation(prevState, curState, prevPositions, curPositions);
default:
return nullptr;