From c447a496fde5a3b97f8a6a99abf121e0b6b1171a Mon Sep 17 00:00:00 2001 From: Bert Peters Date: Tue, 13 Mar 2018 15:06:48 +0100 Subject: [PATCH] Improved visualisation of ReLU on images. --- src/visualisations.cpp | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/src/visualisations.cpp b/src/visualisations.cpp index 1c5856d..771d16b 100644 --- a/src/visualisations.cpp +++ b/src/visualisations.cpp @@ -154,27 +154,28 @@ static Animation *getReLUAnimation(const fmri::LayerData &prevState, const vector &curPositions) { CHECK_EQ(curState.numEntries(), prevState.numEntries()) << "Layers should be of same size!"; - const auto prevData = prevState.data(), curData = curState.data(); - const auto sourceNormalize = getNodeNormalizer(prevState); - const auto sinkNormalize = getNodeNormalizer(curState); + std::vector changes(prevState.numEntries()); + caffe::caffe_sub(prevState.numEntries(), curState.data(), prevState.data(), changes.data()); - EntryList results; - - for (auto i : Range(curState.numEntries())) { - results.emplace_back(curData[i] - prevData[i], make_pair(i / sourceNormalize, i / sinkNormalize)); - } - - results = deduplicate(results); - - const auto maxValue = max_element(results.begin(), results.end())->first; - - return new ActivityAnimation(results, prevPositions.data(), curPositions.data(), -10, [=](float i) -> ActivityAnimation::Color { - if (maxValue == 0) { - return {1, 1, 1}; - } else { - return {1 - i / maxValue, 1 - i / maxValue, 1}; + if (curState.shape().size() == 2) { + EntryList results; + for (auto i : Range(curState.numEntries())) { + results.emplace_back(changes[i], make_pair(i, i)); } - }); + + const auto maxValue = max_element(results.begin(), results.end())->first; + + return new ActivityAnimation(results, prevPositions.data(), curPositions.data(), -10, + [=](float i) -> ActivityAnimation::Color { + if (maxValue == 0) { + return {1, 1, 1}; + } else { + return {1 - i / maxValue, 1 - i / maxValue, 1}; + } + }); + } else { + return new ImageInteractionAnimation(changes.data(), prevState.shape(), prevPositions, curPositions, -10); + } } static Animation *getNormalizingAnimation(const fmri::LayerData &prevState, const LayerData &curState,