Optimize activity strength computation.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
#include <numeric>
|
||||
#include <utility>
|
||||
#include "visualisations.hpp"
|
||||
#include "DummyLayerVisualisation.hpp"
|
||||
@@ -32,7 +33,6 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer)
|
||||
LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
|
||||
|
||||
typedef pair <DType, pair<size_t, size_t>> Entry;
|
||||
vector <Entry> result;
|
||||
|
||||
auto data = prevState.data();
|
||||
|
||||
@@ -40,19 +40,28 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer)
|
||||
|
||||
const auto shape = layer.parameters()[0]->shape();
|
||||
auto weights = layer.parameters()[0]->cpu_data();
|
||||
const auto numEntries = accumulate(shape.begin(), shape.end(), 1u, multiplies<void>());
|
||||
result.reserve(numEntries);
|
||||
const auto numEntries = accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), multiplies<void>());
|
||||
|
||||
vector<float> interactions(numEntries);
|
||||
|
||||
for (auto i : Range(numEntries)) {
|
||||
result.emplace_back(weights[i] * data[i % shape[0]], make_pair(i % shape[0], i / shape[0]));
|
||||
interactions[i] = weights[i] * data[i % shape[0]];
|
||||
}
|
||||
|
||||
const auto desiredSize = min(INTERACTION_LIMIT, result.size());
|
||||
partial_sort(result.begin(), result.begin() + desiredSize, result.end(), [](const Entry &a, const Entry &b) {
|
||||
return abs(a.first) > abs(b.first);
|
||||
// Now use a creative argsort
|
||||
vector<size_t> idx(numEntries);
|
||||
iota(idx.begin(), idx.end(), 0);
|
||||
|
||||
const auto desiredSize = min(INTERACTION_LIMIT, numEntries);
|
||||
partial_sort(idx.begin(), idx.begin() + desiredSize, idx.end(), [&interactions](size_t a, size_t b) {
|
||||
return abs(interactions[a]) > abs(interactions[b]);
|
||||
});
|
||||
|
||||
result.resize(desiredSize);
|
||||
vector<Entry> result;
|
||||
result.reserve(desiredSize);
|
||||
for (auto i : Range(desiredSize)) {
|
||||
result.emplace_back(interactions[idx[i]], make_pair(idx[i] % shape[0], idx[i] / shape[0]));
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user