From 9d2e7c3104d9cefda6934372e50541e069e9a0f8 Mon Sep 17 00:00:00 2001 From: Bert Peters Date: Thu, 15 Feb 2018 18:06:44 +0100 Subject: [PATCH] Optimize activity strength computation. --- src/visualisations.cpp | 25 +++++++++++++++++-------- 1 file changed, 17 insertions(+), 8 deletions(-) diff --git a/src/visualisations.cpp b/src/visualisations.cpp index 2d68faf..aa72cea 100644 --- a/src/visualisations.cpp +++ b/src/visualisations.cpp @@ -1,3 +1,4 @@ +#include #include #include "visualisations.hpp" #include "DummyLayerVisualisation.hpp" @@ -32,7 +33,6 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer) LOG(INFO) << "Computing top interactions for " << layer.name() << endl; typedef pair > Entry; - vector result; auto data = prevState.data(); @@ -40,19 +40,28 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer) const auto shape = layer.parameters()[0]->shape(); auto weights = layer.parameters()[0]->cpu_data(); - const auto numEntries = accumulate(shape.begin(), shape.end(), 1u, multiplies()); - result.reserve(numEntries); + const auto numEntries = accumulate(shape.begin(), shape.end(), static_cast(1), multiplies()); + + vector interactions(numEntries); for (auto i : Range(numEntries)) { - result.emplace_back(weights[i] * data[i % shape[0]], make_pair(i % shape[0], i / shape[0])); + interactions[i] = weights[i] * data[i % shape[0]]; } - const auto desiredSize = min(INTERACTION_LIMIT, result.size()); - partial_sort(result.begin(), result.begin() + desiredSize, result.end(), [](const Entry &a, const Entry &b) { - return abs(a.first) > abs(b.first); + // Now use a creative argsort + vector idx(numEntries); + iota(idx.begin(), idx.end(), 0); + + const auto desiredSize = min(INTERACTION_LIMIT, numEntries); + partial_sort(idx.begin(), idx.begin() + desiredSize, idx.end(), [&interactions](size_t a, size_t b) { + return abs(interactions[a]) > abs(interactions[b]); }); - result.resize(desiredSize); + vector result; + result.reserve(desiredSize); + for (auto i : Range(desiredSize)) { + result.emplace_back(interactions[idx[i]], make_pair(idx[i] % shape[0], idx[i] / shape[0])); + } return result; }