Optimize activity strength computation.

This commit is contained in:
2018-02-15 18:06:44 +01:00
parent e4336b6757
commit 9d2e7c3104

View File

@@ -1,3 +1,4 @@
#include <numeric>
#include <utility>
#include "visualisations.hpp"
#include "DummyLayerVisualisation.hpp"
@@ -32,7 +33,6 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer)
LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
typedef pair <DType, pair<size_t, size_t>> Entry;
vector <Entry> result;
auto data = prevState.data();
@@ -40,19 +40,28 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer)
const auto shape = layer.parameters()[0]->shape();
auto weights = layer.parameters()[0]->cpu_data();
const auto numEntries = accumulate(shape.begin(), shape.end(), 1u, multiplies<void>());
result.reserve(numEntries);
const auto numEntries = accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), multiplies<void>());
vector<float> interactions(numEntries);
for (auto i : Range(numEntries)) {
result.emplace_back(weights[i] * data[i % shape[0]], make_pair(i % shape[0], i / shape[0]));
interactions[i] = weights[i] * data[i % shape[0]];
}
const auto desiredSize = min(INTERACTION_LIMIT, result.size());
partial_sort(result.begin(), result.begin() + desiredSize, result.end(), [](const Entry &a, const Entry &b) {
return abs(a.first) > abs(b.first);
// Now use a creative argsort
vector<size_t> idx(numEntries);
iota(idx.begin(), idx.end(), 0);
const auto desiredSize = min(INTERACTION_LIMIT, numEntries);
partial_sort(idx.begin(), idx.begin() + desiredSize, idx.end(), [&interactions](size_t a, size_t b) {
return abs(interactions[a]) > abs(interactions[b]);
});
result.resize(desiredSize);
vector<Entry> result;
result.reserve(desiredSize);
for (auto i : Range(desiredSize)) {
result.emplace_back(interactions[idx[i]], make_pair(idx[i] % shape[0], idx[i] / shape[0]));
}
return result;
}