Optimize activity strength computation.
This commit is contained in:
@@ -1,3 +1,4 @@
|
|||||||
|
#include <numeric>
|
||||||
#include <utility>
|
#include <utility>
|
||||||
#include "visualisations.hpp"
|
#include "visualisations.hpp"
|
||||||
#include "DummyLayerVisualisation.hpp"
|
#include "DummyLayerVisualisation.hpp"
|
||||||
@@ -32,7 +33,6 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer)
|
|||||||
LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
|
LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
|
||||||
|
|
||||||
typedef pair <DType, pair<size_t, size_t>> Entry;
|
typedef pair <DType, pair<size_t, size_t>> Entry;
|
||||||
vector <Entry> result;
|
|
||||||
|
|
||||||
auto data = prevState.data();
|
auto data = prevState.data();
|
||||||
|
|
||||||
@@ -40,19 +40,28 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer)
|
|||||||
|
|
||||||
const auto shape = layer.parameters()[0]->shape();
|
const auto shape = layer.parameters()[0]->shape();
|
||||||
auto weights = layer.parameters()[0]->cpu_data();
|
auto weights = layer.parameters()[0]->cpu_data();
|
||||||
const auto numEntries = accumulate(shape.begin(), shape.end(), 1u, multiplies<void>());
|
const auto numEntries = accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), multiplies<void>());
|
||||||
result.reserve(numEntries);
|
|
||||||
|
vector<float> interactions(numEntries);
|
||||||
|
|
||||||
for (auto i : Range(numEntries)) {
|
for (auto i : Range(numEntries)) {
|
||||||
result.emplace_back(weights[i] * data[i % shape[0]], make_pair(i % shape[0], i / shape[0]));
|
interactions[i] = weights[i] * data[i % shape[0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto desiredSize = min(INTERACTION_LIMIT, result.size());
|
// Now use a creative argsort
|
||||||
partial_sort(result.begin(), result.begin() + desiredSize, result.end(), [](const Entry &a, const Entry &b) {
|
vector<size_t> idx(numEntries);
|
||||||
return abs(a.first) > abs(b.first);
|
iota(idx.begin(), idx.end(), 0);
|
||||||
|
|
||||||
|
const auto desiredSize = min(INTERACTION_LIMIT, numEntries);
|
||||||
|
partial_sort(idx.begin(), idx.begin() + desiredSize, idx.end(), [&interactions](size_t a, size_t b) {
|
||||||
|
return abs(interactions[a]) > abs(interactions[b]);
|
||||||
});
|
});
|
||||||
|
|
||||||
result.resize(desiredSize);
|
vector<Entry> result;
|
||||||
|
result.reserve(desiredSize);
|
||||||
|
for (auto i : Range(desiredSize)) {
|
||||||
|
result.emplace_back(interactions[idx[i]], make_pair(idx[i] % shape[0], idx[i] / shape[0]));
|
||||||
|
}
|
||||||
|
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user