Compute the 1000 most interacting nodes in each layer.

This works for fully connected layers only, but it works.
This commit is contained in:
2018-02-12 15:51:50 +01:00
parent 37c6cef733
commit 35ad3df4ca
5 changed files with 86 additions and 5 deletions

View File

@@ -37,3 +37,8 @@ LayerInfo::Type LayerInfo::type() const
{
return type_;
}
const std::vector<boost::shared_ptr<caffe::Blob<DType>>>& LayerInfo::parameters() const
{
return parameters_;
}

View File

@@ -26,6 +26,7 @@ namespace fmri
const std::string& name() const;
Type type() const;
const std::vector<boost::shared_ptr<caffe::Blob<DType>>>& parameters() const;
static Type typeByName(std::string_view name);

View File

@@ -105,11 +105,20 @@ static void updateVisualisers()
{
rendererData.layerVisualisations.clear();
rendererData.animations.clear();
LayerData* prevState = nullptr;
LayerVisualisation* prevVisualisation = nullptr;
for (LayerData &layer : *rendererData.currentData) {
LayerVisualisation* visualisation = getVisualisationForLayer(layer);
if (prevState && prevVisualisation && visualisation) {
auto interaction = getActivityAnimation(*prevState, layer, rendererData.layerInfo.at(layer.name()), prevVisualisation->nodePositions(), visualisation->nodePositions());
rendererData.animations.emplace_back(interaction);
}
rendererData.layerVisualisations.emplace_back(visualisation);
prevVisualisation = visualisation;
prevState = &layer;
}
glutPostRedisplay();

View File

@@ -1,11 +1,15 @@
//
// Created by bert on 09/02/18.
//
#include <utility>
#include "visualisations.hpp"
#include "DummyLayerVisualisation.hpp"
#include "MultiImageVisualisation.hpp"
#include "FlatLayerVisualisation.hpp"
#include "Range.hpp"
using namespace fmri;
using namespace std;
// Maximum number of interactions shown
static constexpr size_t INTERACTION_LIMIT = 1000;
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer)
{
@@ -20,3 +24,59 @@ fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &
return new DummyLayerVisualisation();
}
}
static vector <pair<DType, pair<size_t, size_t>>>
computeActivityStrengths(const LayerData &prevState, const LayerInfo &layer)
{
typedef pair <DType, pair<size_t, size_t>> Entry;
vector <Entry> result;
auto data = prevState.data();
CHECK_GE(layer.parameters().size(), 1) << "Layer should have correct parameters";
const auto shape = layer.parameters()[0]->shape();
auto weights = layer.parameters()[0]->cpu_data();
for (auto i : Range(accumulate(shape.begin(), shape.end(), 1, multiplies<void>()))) {
result.emplace_back(weights[i] * data[i % shape[0]], make_pair(i % shape[0], i / shape[0]));
}
const auto desiredSize = min(INTERACTION_LIMIT, result.size());
partial_sort(result.begin(), result.begin() + desiredSize, result.end(), [](const Entry &a, const Entry &b) {
return abs(a.first) > abs(b.first);
});
result.resize(desiredSize);
return result;
}
fmri::ActivityAnimation *fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState,
const fmri::LayerInfo &layer, const vector<float> &prevPositions,
const vector<float> &curPositions)
{
if (layer.type() != LayerInfo::Type::InnerProduct) {
// Only supported type at this time
return nullptr;
}
if (prevPositions.empty() || curPositions.empty()) {
// Not all positions know, no visualisation possible.
return nullptr;
}
const auto entries = computeActivityStrengths(prevState, layer);
const auto bufferSize = 3 * entries.size();
unique_ptr<float[]> startingPositions(new float[bufferSize]);
unique_ptr<float[]> endingPositions(new float[bufferSize]);
for (auto i : Range(entries.size())) {
memcpy(startingPositions.get() + 3 * i, prevPositions.data() + 3 * entries[i].second.first, 3 * sizeof(float));
memcpy(endingPositions.get() + 3 * i, curPositions.data() + 3 * entries[i].second.second, 3 * sizeof(float));
}
// TODO: actually do something
return new ActivityAnimation(entries.size(), startingPositions.get(), endingPositions.get(), 2.0);
}

View File

@@ -2,6 +2,8 @@
#include "LayerVisualisation.hpp"
#include "LayerData.hpp"
#include "ActivityAnimation.hpp"
#include "LayerInfo.hpp"
namespace fmri {
/**
@@ -11,4 +13,8 @@ namespace fmri {
* @return A (possibly empty) visualisation. The caller is responsible for deallocating.
*/
LayerVisualisation* getVisualisationForLayer(const LayerData& layer);
}
ActivityAnimation *getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState,
const fmri::LayerInfo &layer, const vector<float> &prevPositions,
const vector<float> &curPositions);
}