Fix out-of-bounds read on node coordinates.

This commit is contained in:
2018-02-20 12:22:34 +01:00
parent 219e8cbec9
commit 3c5358d8d6

View File

@@ -15,6 +15,32 @@ static constexpr size_t INTERACTION_LIMIT = 10000;
typedef vector<pair<float, pair<size_t, size_t>>> EntryList; typedef vector<pair<float, pair<size_t, size_t>>> EntryList;
/**
* Normalizer for node positions.
*
* Since not every neuron in a layer may get a node in the visualisation,
* this function maps those neurons back to a node number that does.
*
* Usage: node / getNodeNormalizer(layer).
*
* @param layer Layer to compute normalization for
* @return Number to divide node numbers by.
*/
static inline int getNodeNormalizer(const LayerData& layer) {
const auto& shape = layer.shape();
switch(shape.size()) {
case 2:
return 1;
case 4:
return shape[2] * shape[3];
default:
CHECK(false) << "Unsupported shape " << shape.size() << endl;
exit(EINVAL);
}
}
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer) fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer)
{ {
switch (layer.shape().size()) { switch (layer.shape().size()) {
@@ -55,22 +81,27 @@ static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, c
EntryList result; EntryList result;
result.reserve(desiredSize); result.reserve(desiredSize);
const auto normalizer = getNodeNormalizer(prevState);
for (auto i : idx) { for (auto i : idx) {
result.emplace_back(interactions[i], make_pair(i / shape[0], i % shape[0])); result.emplace_back(interactions[i], make_pair(i / shape[0] / normalizer, i % shape[0]));
} }
return new ActivityAnimation(result, prevPositions.data(), curPositions.data(), -10); return new ActivityAnimation(result, prevPositions.data(), curPositions.data(), -10);
} }
static Animation *getDropOutAnimation(const fmri::LayerData &curState, static Animation *getDropOutAnimation(const fmri::LayerData &prevState,
const fmri::LayerData &curState,
const vector<float> &prevPositions, const vector<float> &prevPositions,
const vector<float> &curPositions) { const vector<float> &curPositions) {
const auto sourceNormalize = getNodeNormalizer(prevState);
const auto sinkNormalize = getNodeNormalizer(curState);
auto data = curState.data(); auto data = curState.data();
EntryList results; EntryList results;
results.reserve(curState.numEntries()); results.reserve(curState.numEntries());
for (auto i : Range(curState.numEntries())) { for (auto i : Range(curState.numEntries())) {
if (data[i] != 0) { if (data[i] != 0) {
results.emplace_back(data[i], make_pair(i, i)); results.emplace_back(data[i], make_pair(i / sourceNormalize, i / sinkNormalize));
} }
} }
@@ -93,7 +124,7 @@ Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const f
prevPositions, curPositions); prevPositions, curPositions);
case LayerInfo::Type::DropOut: case LayerInfo::Type::DropOut:
return getDropOutAnimation(curState, prevPositions, curPositions); return getDropOutAnimation(prevState, curState, prevPositions, curPositions);
default: default: