From 35ad3df4ca98653ac1ce70185b3a401af8bce669 Mon Sep 17 00:00:00 2001 From: Bert Peters Date: Mon, 12 Feb 2018 15:51:50 +0100 Subject: [PATCH] Compute the 1000 most interacting nodes in each layer. This works for fully connected layers only, but it works. --- src/LayerInfo.cpp | 5 ++++ src/LayerInfo.hpp | 1 + src/main.cpp | 9 ++++++ src/visualisations.cpp | 68 +++++++++++++++++++++++++++++++++++++++--- src/visualisations.hpp | 8 ++++- 5 files changed, 86 insertions(+), 5 deletions(-) diff --git a/src/LayerInfo.cpp b/src/LayerInfo.cpp index 47d14c3..a6c63bd 100644 --- a/src/LayerInfo.cpp +++ b/src/LayerInfo.cpp @@ -37,3 +37,8 @@ LayerInfo::Type LayerInfo::type() const { return type_; } + +const std::vector>>& LayerInfo::parameters() const +{ + return parameters_; +} diff --git a/src/LayerInfo.hpp b/src/LayerInfo.hpp index 9692b4b..13ef551 100644 --- a/src/LayerInfo.hpp +++ b/src/LayerInfo.hpp @@ -26,6 +26,7 @@ namespace fmri const std::string& name() const; Type type() const; + const std::vector>>& parameters() const; static Type typeByName(std::string_view name); diff --git a/src/main.cpp b/src/main.cpp index 573a54d..6e46145 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -105,11 +105,20 @@ static void updateVisualisers() { rendererData.layerVisualisations.clear(); rendererData.animations.clear(); + LayerData* prevState = nullptr; + LayerVisualisation* prevVisualisation = nullptr; for (LayerData &layer : *rendererData.currentData) { LayerVisualisation* visualisation = getVisualisationForLayer(layer); + if (prevState && prevVisualisation && visualisation) { + auto interaction = getActivityAnimation(*prevState, layer, rendererData.layerInfo.at(layer.name()), prevVisualisation->nodePositions(), visualisation->nodePositions()); + rendererData.animations.emplace_back(interaction); + } rendererData.layerVisualisations.emplace_back(visualisation); + + prevVisualisation = visualisation; + prevState = &layer; } glutPostRedisplay(); diff --git a/src/visualisations.cpp b/src/visualisations.cpp index be97cde..88f0ce4 100644 --- a/src/visualisations.cpp +++ b/src/visualisations.cpp @@ -1,11 +1,15 @@ -// -// Created by bert on 09/02/18. -// - +#include #include "visualisations.hpp" #include "DummyLayerVisualisation.hpp" #include "MultiImageVisualisation.hpp" #include "FlatLayerVisualisation.hpp" +#include "Range.hpp" + +using namespace fmri; +using namespace std; + +// Maximum number of interactions shown +static constexpr size_t INTERACTION_LIMIT = 1000; fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer) { @@ -20,3 +24,59 @@ fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData & return new DummyLayerVisualisation(); } } + +static vector >> +computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer) +{ + typedef pair > Entry; + vector result; + + auto data = prevState.data(); + + CHECK_GE(layer.parameters().size(), 1) << "Layer should have correct parameters"; + + const auto shape = layer.parameters()[0]->shape(); + auto weights = layer.parameters()[0]->cpu_data(); + + for (auto i : Range(accumulate(shape.begin(), shape.end(), 1, multiplies()))) { + result.emplace_back(weights[i] * data[i % shape[0]], make_pair(i % shape[0], i / shape[0])); + } + + const auto desiredSize = min(INTERACTION_LIMIT, result.size()); + partial_sort(result.begin(), result.begin() + desiredSize, result.end(), [](const Entry &a, const Entry &b) { + return abs(a.first) > abs(b.first); + }); + + result.resize(desiredSize); + + return result; +} + +fmri::ActivityAnimation *fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState, + const fmri::LayerInfo &layer, const vector &prevPositions, + const vector &curPositions) +{ + if (layer.type() != LayerInfo::Type::InnerProduct) { + // Only supported type at this time + return nullptr; + } + + if (prevPositions.empty() || curPositions.empty()) { + // Not all positions know, no visualisation possible. + return nullptr; + } + + const auto entries = computeActivityStrengths(prevState, layer); + const auto bufferSize = 3 * entries.size(); + + unique_ptr startingPositions(new float[bufferSize]); + unique_ptr endingPositions(new float[bufferSize]); + + for (auto i : Range(entries.size())) { + memcpy(startingPositions.get() + 3 * i, prevPositions.data() + 3 * entries[i].second.first, 3 * sizeof(float)); + memcpy(endingPositions.get() + 3 * i, curPositions.data() + 3 * entries[i].second.second, 3 * sizeof(float)); + } + + // TODO: actually do something + return new ActivityAnimation(entries.size(), startingPositions.get(), endingPositions.get(), 2.0); +} diff --git a/src/visualisations.hpp b/src/visualisations.hpp index 7ea8f40..9a5549c 100644 --- a/src/visualisations.hpp +++ b/src/visualisations.hpp @@ -2,6 +2,8 @@ #include "LayerVisualisation.hpp" #include "LayerData.hpp" +#include "ActivityAnimation.hpp" +#include "LayerInfo.hpp" namespace fmri { /** @@ -11,4 +13,8 @@ namespace fmri { * @return A (possibly empty) visualisation. The caller is responsible for deallocating. */ LayerVisualisation* getVisualisationForLayer(const LayerData& layer); -} \ No newline at end of file + + ActivityAnimation *getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState, + const fmri::LayerInfo &layer, const vector &prevPositions, + const vector &curPositions); +}