From 1e9cc7c6c59febf95f839c22ab2310953bb05645 Mon Sep 17 00:00:00 2001 From: Bert Peters Date: Mon, 19 Feb 2018 16:23:47 +0100 Subject: [PATCH] Generalize animation generation. --- src/utils.hpp | 31 +++++++++++++++++++++++++++++++ src/visualisations.cpp | 41 +++++++++++++++++++---------------------- 2 files changed, 50 insertions(+), 22 deletions(-) diff --git a/src/utils.hpp b/src/utils.hpp index 7deaf70..97bbfb9 100644 --- a/src/utils.hpp +++ b/src/utils.hpp @@ -176,4 +176,35 @@ namespace fmri return static_cast(step.count()) / static_cast(modified_length.count()); } + /** + * Perform an argsort partitioning on the first n elements. + * + * @tparam Iter + * @tparam Compare + * @param first First element + * @param middle Sorting limit + * @param last Past end iterator for range + * @param compare Comparison function to use + * @return A vector of the indices before the partitioning cut-off. + */ + template + std::vector arg_nth_element(Iter first, Iter middle, Iter last, Compare compare) + { + using namespace std; + + const auto n = static_cast(distance(first, middle)); + const auto total = static_cast(distance(first, last)); + + vector indices(total); + iota(indices.begin(), indices.end(), 0u); + + nth_element(indices.begin(), indices.begin() + n, indices.end(), [=](size_t a, size_t b) { + return compare(*(first + a), *(first + b)); + }); + + indices.resize(n); + + return indices; + } + } diff --git a/src/visualisations.cpp b/src/visualisations.cpp index 3eac6d3..156fcff 100644 --- a/src/visualisations.cpp +++ b/src/visualisations.cpp @@ -11,7 +11,7 @@ using namespace fmri; using namespace std; // Maximum number of interactions shown -static constexpr size_t INTERACTION_LIMIT = 1000; +static constexpr size_t INTERACTION_LIMIT = 10000; fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer) { @@ -27,12 +27,12 @@ fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData & } } -static vector >> -computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer) +static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, const fmri::LayerInfo &layer, + const vector &prevPositions, const vector &curPositions) { LOG(INFO) << "Computing top interactions for " << layer.name() << endl; - typedef pair > Entry; + typedef pair> Entry; auto data = prevState.data(); @@ -48,38 +48,35 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer) interactions[i] = weights[i] * data[i % shape[0]]; } - // Now use a creative argsort - vector idx(numEntries); - iota(idx.begin(), idx.end(), 0); - const auto desiredSize = min(INTERACTION_LIMIT, numEntries); - nth_element(idx.begin(), idx.begin() + desiredSize, idx.end(), [&interactions](size_t a, size_t b) { - return abs(interactions[a]) > abs(interactions[b]); + auto idx = arg_nth_element(interactions.begin(), interactions.begin() + desiredSize, interactions.end(), [](auto a, auto b) { + return abs(a) > abs(b); }); vector result; result.reserve(desiredSize); - for (auto i : Range(desiredSize)) { - result.emplace_back(interactions[idx[i]], make_pair(idx[i] / shape[0], idx[i] % shape[0])); + for (auto i : idx) { + result.emplace_back(interactions[i], make_pair(i / shape[0], i % shape[0])); } - return result; + return new ActivityAnimation(result, prevPositions.data(), curPositions.data(), -10); } -Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerInfo &layer, - const vector &prevPositions, const vector &curPositions) +Animation *fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerInfo &layer, + const vector &prevPositions, const vector &curPositions) { - if (layer.type() != LayerInfo::Type::InnerProduct) { - // Only supported type at this time - return nullptr; - } - if (prevPositions.empty() || curPositions.empty()) { // Not all positions know, no visualisation possible. return nullptr; } - const auto entries = computeActivityStrengths(prevState, layer); - return new ActivityAnimation(entries, prevPositions.data(), curPositions.data(), -10); + switch (layer.type()) { + case LayerInfo::Type::InnerProduct: + return getFullyConnectedAnimation(prevState, layer, + prevPositions, curPositions); + + default: + return nullptr; + } }