Implement layer interactions for softmax layer.

This commit is contained in:
2018-04-06 13:44:54 +02:00
parent 7ab195c462
commit 368cc7c006

View File

@@ -221,6 +221,24 @@ static Animation *getNormalizingAnimation(const fmri::LayerData &prevState, cons
}
}
static Animation *getSoftmaxAnimation(const fmri::LayerData &curState, const vector<float> &prevPositions,
const vector<float> &curPositions)
{
CHECK_EQ(curState.shape().size(), 2) << "Softmax only supported for flat layers.";
std::vector<float> intensities(curState.data(), curState.data() + curState.numEntries());
rescale(intensities.begin(), intensities.end(), 0, 1);
EntryList entries;
for (auto i = 0u; i < intensities.size(); ++i) {
entries.emplace_back(intensities[i], make_pair(i, i));
}
return new ActivityAnimation(entries, prevPositions.data(), curPositions.data(), [](auto i) -> Color {
return {1 - i, 1 - i, 1};
});
}
Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState,
const fmri::LayerInfo &layer, const vector<float> &prevPositions,
const vector<float> &curPositions)
@@ -248,6 +266,9 @@ Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const f
case LayerInfo::Type::LRN:
return getNormalizingAnimation(prevState, curState, prevPositions, curPositions);
case LayerInfo::Type::Softmax:
return getSoftmaxAnimation(curState, prevPositions, curPositions);
default:
return nullptr;
}