Bugfix: actually compute interactions in the right order.

Fix #5.
This commit is contained in:
2018-03-27 16:06:10 +02:00
parent a72bb7542b
commit 8c8d67de73

View File

@@ -106,10 +106,10 @@ static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, c
const auto numEntries = accumulate(shape.begin(), shape.end(), static_cast<size_t>(1), multiplies<void>());
vector<float> interactions(numEntries);
const auto stepSize = shape[0];
const auto stepSize = shape[1];
for (auto i : Range(numEntries / stepSize)) {
caffe::caffe_mul(shape[0], &weights[i * stepSize], data, &interactions[i * stepSize]);
caffe::caffe_mul(shape[1], &weights[i * stepSize], data, &interactions[i * stepSize]);
}
const auto desiredSize = min(INTERACTION_LIMIT, numEntries);
@@ -125,9 +125,26 @@ static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, c
if (abs(interactions[i]) < EPSILON){
break;
}
result.emplace_back(interactions[i], make_pair((i % shape[1]) / normalizer, i / shape[1]));
result.emplace_back(interactions[i], make_pair((i % shape[1]), i / shape[1]));
}
for (auto entry : result) {
if (prevState.data()[entry.second.first] < EPSILON) {
std::cerr << "Error in data!" << entry.first << " "
<< entry.second.first << " " << entry.second.second
<< " " << prevState.data()[entry.second.first] << " "
<< "\n";
}
}
if (normalizer != 1) {
for (auto& entry : result) {
entry.second.first /= normalizer;
}
}
cerr.flush();
return new ActivityAnimation(result, prevPositions.data(), curPositions.data());
}