diff --git a/src/fmri/visualisations.cpp b/src/fmri/visualisations.cpp index 40a8358..814edf2 100644 --- a/src/fmri/visualisations.cpp +++ b/src/fmri/visualisations.cpp @@ -221,6 +221,24 @@ static Animation *getNormalizingAnimation(const fmri::LayerData &prevState, cons } } +static Animation *getSoftmaxAnimation(const fmri::LayerData &curState, const vector &prevPositions, + const vector &curPositions) +{ + CHECK_EQ(curState.shape().size(), 2) << "Softmax only supported for flat layers."; + + std::vector intensities(curState.data(), curState.data() + curState.numEntries()); + rescale(intensities.begin(), intensities.end(), 0, 1); + + EntryList entries; + for (auto i = 0u; i < intensities.size(); ++i) { + entries.emplace_back(intensities[i], make_pair(i, i)); + } + + return new ActivityAnimation(entries, prevPositions.data(), curPositions.data(), [](auto i) -> Color { + return {1 - i, 1 - i, 1}; + }); +} + Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState, const fmri::LayerInfo &layer, const vector &prevPositions, const vector &curPositions) @@ -248,6 +266,9 @@ Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const f case LayerInfo::Type::LRN: return getNormalizingAnimation(prevState, curState, prevPositions, curPositions); + case LayerInfo::Type::Softmax: + return getSoftmaxAnimation(curState, prevPositions, curPositions); + default: return nullptr; }