Implement visualisation of dropout layers.
This commit is contained in:
@@ -16,6 +16,8 @@ LayerInfo::Type LayerInfo::typeByName(string_view name)
|
|||||||
return Type::Pooling;
|
return Type::Pooling;
|
||||||
} else if (name == "InnerProduct") {
|
} else if (name == "InnerProduct") {
|
||||||
return Type::InnerProduct;
|
return Type::InnerProduct;
|
||||||
|
} else if (name == "Dropout") {
|
||||||
|
return Type::DropOut;
|
||||||
} else {
|
} else {
|
||||||
LOG(INFO) << "Received unknown layer type: " << name << endl;
|
LOG(INFO) << "Received unknown layer type: " << name << endl;
|
||||||
return Type::Other;
|
return Type::Other;
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ namespace fmri
|
|||||||
Pooling,
|
Pooling,
|
||||||
Output,
|
Output,
|
||||||
InnerProduct,
|
InnerProduct,
|
||||||
|
DropOut,
|
||||||
Other
|
Other
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,8 @@ using namespace std;
|
|||||||
// Maximum number of interactions shown
|
// Maximum number of interactions shown
|
||||||
static constexpr size_t INTERACTION_LIMIT = 10000;
|
static constexpr size_t INTERACTION_LIMIT = 10000;
|
||||||
|
|
||||||
|
typedef vector<pair<float, pair<size_t, size_t>>> EntryList;
|
||||||
|
|
||||||
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer)
|
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &layer)
|
||||||
{
|
{
|
||||||
switch (layer.shape().size()) {
|
switch (layer.shape().size()) {
|
||||||
@@ -32,8 +34,6 @@ static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, c
|
|||||||
{
|
{
|
||||||
LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
|
LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
|
||||||
|
|
||||||
typedef pair<DType, pair<size_t, size_t>> Entry;
|
|
||||||
|
|
||||||
auto data = prevState.data();
|
auto data = prevState.data();
|
||||||
|
|
||||||
CHECK_GE(layer.parameters().size(), 1) << "Layer should have correct parameters";
|
CHECK_GE(layer.parameters().size(), 1) << "Layer should have correct parameters";
|
||||||
@@ -53,7 +53,7 @@ static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, c
|
|||||||
return abs(a) > abs(b);
|
return abs(a) > abs(b);
|
||||||
});
|
});
|
||||||
|
|
||||||
vector<Entry> result;
|
EntryList result;
|
||||||
result.reserve(desiredSize);
|
result.reserve(desiredSize);
|
||||||
for (auto i : idx) {
|
for (auto i : idx) {
|
||||||
result.emplace_back(interactions[i], make_pair(i / shape[0], i % shape[0]));
|
result.emplace_back(interactions[i], make_pair(i / shape[0], i % shape[0]));
|
||||||
@@ -62,12 +62,27 @@ static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, c
|
|||||||
return new ActivityAnimation(result, prevPositions.data(), curPositions.data(), -10);
|
return new ActivityAnimation(result, prevPositions.data(), curPositions.data(), -10);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
static Animation *getDropOutAnimation(const fmri::LayerData &curState,
|
||||||
|
const vector<float> &prevPositions,
|
||||||
|
const vector<float> &curPositions) {
|
||||||
|
auto data = curState.data();
|
||||||
|
EntryList results;
|
||||||
|
results.reserve(curState.numEntries());
|
||||||
|
for (auto i : Range(curState.numEntries())) {
|
||||||
|
if (data[i] != 0) {
|
||||||
|
results.emplace_back(data[i], make_pair(i, i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return new ActivityAnimation(results, prevPositions.data(), curPositions.data(), -10);
|
||||||
|
}
|
||||||
|
|
||||||
Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState,
|
Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const fmri::LayerData &curState,
|
||||||
const fmri::LayerInfo &layer, const vector<float> &prevPositions,
|
const fmri::LayerInfo &layer, const vector<float> &prevPositions,
|
||||||
const vector<float> &curPositions)
|
const vector<float> &curPositions)
|
||||||
{
|
{
|
||||||
if (prevPositions.empty() || curPositions.empty()) {
|
if (prevPositions.empty() || curPositions.empty()) {
|
||||||
// Not all positions know, no visualisation possible.
|
// Not all positions known, no visualisation possible.
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -77,6 +92,10 @@ Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const f
|
|||||||
return getFullyConnectedAnimation(prevState, layer,
|
return getFullyConnectedAnimation(prevState, layer,
|
||||||
prevPositions, curPositions);
|
prevPositions, curPositions);
|
||||||
|
|
||||||
|
case LayerInfo::Type::DropOut:
|
||||||
|
return getDropOutAnimation(curState, prevPositions, curPositions);
|
||||||
|
|
||||||
|
|
||||||
default:
|
default:
|
||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user