Generalize animation generation.
This commit is contained in:
@@ -176,4 +176,35 @@ namespace fmri
|
|||||||
return static_cast<float>(step.count()) / static_cast<float>(modified_length.count());
|
return static_cast<float>(step.count()) / static_cast<float>(modified_length.count());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Perform an argsort partitioning on the first n elements.
|
||||||
|
*
|
||||||
|
* @tparam Iter
|
||||||
|
* @tparam Compare
|
||||||
|
* @param first First element
|
||||||
|
* @param middle Sorting limit
|
||||||
|
* @param last Past end iterator for range
|
||||||
|
* @param compare Comparison function to use
|
||||||
|
* @return A vector of the indices before the partitioning cut-off.
|
||||||
|
*/
|
||||||
|
template<class Iter, class Compare>
|
||||||
|
std::vector<std::size_t> arg_nth_element(Iter first, Iter middle, Iter last, Compare compare)
|
||||||
|
{
|
||||||
|
using namespace std;
|
||||||
|
|
||||||
|
const auto n = static_cast<size_t>(distance(first, middle));
|
||||||
|
const auto total = static_cast<size_t>(distance(first, last));
|
||||||
|
|
||||||
|
vector<size_t> indices(total);
|
||||||
|
iota(indices.begin(), indices.end(), 0u);
|
||||||
|
|
||||||
|
nth_element(indices.begin(), indices.begin() + n, indices.end(), [=](size_t a, size_t b) {
|
||||||
|
return compare(*(first + a), *(first + b));
|
||||||
|
});
|
||||||
|
|
||||||
|
indices.resize(n);
|
||||||
|
|
||||||
|
return indices;
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ using namespace fmri;
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
// Maximum number of interactions shown
|
// Maximum number of interactions shown
|
||||||
static constexpr size_t INTERACTION_LIMIT = 1000;
|
static constexpr size_t INTERACTION_LIMIT = 10000;
|
||||||
|
|
||||||
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer)
|
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer)
|
||||||
{
|
{
|
||||||
@@ -27,12 +27,12 @@ fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
static vector <pair<DType, pair<size_t, size_t>>>
|
static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, const fmri::LayerInfo &layer,
|
||||||
computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer)
|
const vector<float> &prevPositions, const vector<float> &curPositions)
|
||||||
{
|
{
|
||||||
LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
|
LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
|
||||||
|
|
||||||
typedef pair <DType, pair<size_t, size_t>> Entry;
|
typedef pair<DType, pair<size_t, size_t>> Entry;
|
||||||
|
|
||||||
auto data = prevState.data();
|
auto data = prevState.data();
|
||||||
|
|
||||||
@@ -48,38 +48,35 @@ computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer)
|
|||||||
interactions[i] = weights[i] * data[i % shape[0]];
|
interactions[i] = weights[i] * data[i % shape[0]];
|
||||||
}
|
}
|
||||||
|
|
||||||
// Now use a creative argsort
|
|
||||||
vector<size_t> idx(numEntries);
|
|
||||||
iota(idx.begin(), idx.end(), 0);
|
|
||||||
|
|
||||||
const auto desiredSize = min(INTERACTION_LIMIT, numEntries);
|
const auto desiredSize = min(INTERACTION_LIMIT, numEntries);
|
||||||
nth_element(idx.begin(), idx.begin() + desiredSize, idx.end(), [&interactions](size_t a, size_t b) {
|
auto idx = arg_nth_element(interactions.begin(), interactions.begin() + desiredSize, interactions.end(), [](auto a, auto b) {
|
||||||
return abs(interactions[a]) > abs(interactions[b]);
|
return abs(a) > abs(b);
|
||||||
});
|
});
|
||||||
|
|
||||||
vector<Entry> result;
|
vector<Entry> result;
|
||||||
result.reserve(desiredSize);
|
result.reserve(desiredSize);
|
||||||
for (auto i : Range(desiredSize)) {
|
for (auto i : idx) {
|
||||||
result.emplace_back(interactions[idx[i]], make_pair(idx[i] / shape[0], idx[i] % shape[0]));
|
result.emplace_back(interactions[i], make_pair(i / shape[0], i % shape[0]));
|
||||||
}
|
}
|
||||||
|
|
||||||
return result;
|
return new ActivityAnimation(result, prevPositions.data(), curPositions.data(), -10);
|
||||||
}
|
}
|
||||||
|
|
||||||
Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerInfo &layer,
|
Animation *fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerInfo &layer,
|
||||||
const vector<float> &prevPositions, const vector<float> &curPositions)
|
const vector<float> &prevPositions, const vector<float> &curPositions)
|
||||||
{
|
{
|
||||||
if (layer.type() != LayerInfo::Type::InnerProduct) {
|
|
||||||
// Only supported type at this time
|
|
||||||
return nullptr;
|
|
||||||
}
|
|
||||||
|
|
||||||
if (prevPositions.empty() || curPositions.empty()) {
|
if (prevPositions.empty() || curPositions.empty()) {
|
||||||
// Not all positions know, no visualisation possible.
|
// Not all positions know, no visualisation possible.
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
const auto entries = computeActivityStrengths(prevState, layer);
|
|
||||||
|
|
||||||
return new ActivityAnimation(entries, prevPositions.data(), curPositions.data(), -10);
|
switch (layer.type()) {
|
||||||
|
case LayerInfo::Type::InnerProduct:
|
||||||
|
return getFullyConnectedAnimation(prevState, layer,
|
||||||
|
prevPositions, curPositions);
|
||||||
|
|
||||||
|
default:
|
||||||
|
return nullptr;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user