Implement some visualisation for ReLU layers.
This commit is contained in:
@@ -16,7 +16,6 @@ namespace fmri
|
||||
Convolutional,
|
||||
ReLU,
|
||||
Pooling,
|
||||
Output,
|
||||
InnerProduct,
|
||||
DropOut,
|
||||
Other
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
#include <algorithm>
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include "visualisations.hpp"
|
||||
#include "DummyLayerVisualisation.hpp"
|
||||
#include "MultiImageVisualisation.hpp"
|
||||
@@ -41,6 +41,29 @@ static inline int getNodeNormalizer(const LayerData& layer) {
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Deduplicate interaction entries.
|
||||
*
|
||||
* For duplicate interactions, the interaction strengths are summed.
|
||||
*
|
||||
* @param entries
|
||||
* @return the deduplicated entries.
|
||||
*/
|
||||
static EntryList deduplicate(const EntryList& entries)
|
||||
{
|
||||
map<pair<size_t, size_t>, float> combiner;
|
||||
for (auto entry : entries) {
|
||||
combiner[entry.second] += entry.first;
|
||||
}
|
||||
|
||||
EntryList result;
|
||||
transform(combiner.begin(), combiner.end(), back_inserter(result), [](const auto& item) {
|
||||
return make_pair(item.second, item.first);
|
||||
});
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer)
|
||||
{
|
||||
switch (layer.shape().size()) {
|
||||
@@ -105,6 +128,29 @@ static Animation *getDropOutAnimation(const fmri::LayerData &prevState,
|
||||
}
|
||||
}
|
||||
|
||||
results = deduplicate(results);
|
||||
|
||||
return new ActivityAnimation(results, prevPositions.data(), curPositions.data(), -10);
|
||||
}
|
||||
|
||||
static Animation *getReLUAnimation(const fmri::LayerData &prevState,
|
||||
const fmri::LayerData &curState,
|
||||
const vector<float> &prevPositions,
|
||||
const vector<float> &curPositions) {
|
||||
CHECK_EQ(curState.numEntries(), prevState.numEntries()) << "Layers should be of same size!";
|
||||
|
||||
const auto prevData = prevState.data(), curData = curState.data();
|
||||
const auto sourceNormalize = getNodeNormalizer(prevState);
|
||||
const auto sinkNormalize = getNodeNormalizer(curState);
|
||||
|
||||
EntryList results;
|
||||
|
||||
for (auto i : Range(curState.numEntries())) {
|
||||
results.emplace_back(curData[i] - prevData[i], make_pair(i / sourceNormalize, i / sinkNormalize));
|
||||
}
|
||||
|
||||
results = deduplicate(results);
|
||||
|
||||
return new ActivityAnimation(results, prevPositions.data(), curPositions.data(), -10);
|
||||
}
|
||||
|
||||
@@ -126,6 +172,8 @@ Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const f
|
||||
case LayerInfo::Type::DropOut:
|
||||
return getDropOutAnimation(prevState, curState, prevPositions, curPositions);
|
||||
|
||||
case LayerInfo::Type::ReLU:
|
||||
return getReLUAnimation(prevState, curState, prevPositions, curPositions);
|
||||
|
||||
default:
|
||||
return nullptr;
|
||||
|
||||
Reference in New Issue
Block a user