Clean up on layer visualisations.

This commit is contained in:
2018-02-09 11:55:48 +01:00
parent cf5dd28fef
commit a84e4e80a2
6 changed files with 47 additions and 18 deletions

View File

@@ -0,0 +1,16 @@
#pragma once
#include "LayerVisualisation.hpp"
namespace fmri
{
/**
* Visualisation that does not actually do anything.
*/
class DummyLayerVisualisation : public LayerVisualisation
{
public:
void render() override
{};
};
}

View File

@@ -14,6 +14,8 @@ LayerInfo::Type LayerInfo::typeByName(string_view name)
return Type::ReLU; return Type::ReLU;
} else if (name == "Pooling") { } else if (name == "Pooling") {
return Type::Pooling; return Type::Pooling;
} else if (name == "InnerProduct") {
return Type::InnerProduct;
} else { } else {
LOG(INFO) << "Received unknown layer type: " << name << endl; LOG(INFO) << "Received unknown layer type: " << name << endl;
return Type::Other; return Type::Other;
@@ -24,7 +26,6 @@ LayerInfo::LayerInfo(string_view name, string_view type,
const vector<boost::shared_ptr<caffe::Blob<DType>>> &parameters) const vector<boost::shared_ptr<caffe::Blob<DType>>> &parameters)
: parameters_(parameters), type_(typeByName(type)), name_(name) : parameters_(parameters), type_(typeByName(type)), name_(name)
{ {
} }
const std::string &LayerInfo::name() const const std::string &LayerInfo::name() const

View File

@@ -17,6 +17,7 @@ namespace fmri
ReLU, ReLU,
Pooling, Pooling,
Output, Output,
InnerProduct,
Other Other
}; };

View File

@@ -1,4 +1,4 @@
#include <GL/glut.h> #include <GL/freeglut.h>
#include <cmath> #include <cmath>
#include <sstream> #include <sstream>
#include <iostream> #include <iostream>
@@ -60,7 +60,8 @@ static void handleKeys(unsigned char key, int, int)
case 'q': case 'q':
// Utility quit function. // Utility quit function.
exit(0); glutLeaveMainLoop();
break;
case 'h': case 'h':
camera.reset(); camera.reset();

View File

@@ -16,6 +16,6 @@ namespace fmri
static Camera& instance(); static Camera& instance();
private: private:
Camera() = default; Camera() noexcept = default;
}; };
} }

View File

@@ -14,6 +14,8 @@
#include "FlatLayerVisualisation.hpp" #include "FlatLayerVisualisation.hpp"
#include "MultiImageVisualisation.hpp" #include "MultiImageVisualisation.hpp"
#include "Range.hpp" #include "Range.hpp"
#include "ActivityAnimation.hpp"
#include "DummyLayerVisualisation.hpp"
using namespace std; using namespace std;
using namespace fmri; using namespace fmri;
@@ -25,6 +27,7 @@ struct
vector<vector<LayerData>> data; vector<vector<LayerData>> data;
vector<vector<LayerData>>::iterator currentData; vector<vector<LayerData>>::iterator currentData;
vector<unique_ptr<LayerVisualisation>> layerVisualisations; vector<unique_ptr<LayerVisualisation>> layerVisualisations;
vector<unique_ptr<ActivityAnimation>> animations;
} rendererData; } rendererData;
static void loadSimulationData(const Options &options) static void loadSimulationData(const Options &options)
@@ -84,6 +87,8 @@ static void render()
glPopMatrix(); glPopMatrix();
glTranslatef(-10, 0, 0); glTranslatef(-10, 0, 0);
} }
glPopMatrix(); glPopMatrix();
glutSwapBuffers(); glutSwapBuffers();
@@ -98,24 +103,15 @@ static void renderLayerName(const LayerData &data)
glTranslatef(0, 0, -10); glTranslatef(0, 0, -10);
} }
static LayerVisualisation *getVisualisationForLayer(const LayerData &layer);
static void updateVisualisers() static void updateVisualisers()
{ {
rendererData.layerVisualisations.clear(); rendererData.layerVisualisations.clear();
rendererData.animations.clear();
for (auto &layer : *rendererData.currentData) { for (LayerData &layer : *rendererData.currentData) {
LayerVisualisation* visualisation = nullptr; LayerVisualisation* visualisation = getVisualisationForLayer(layer);
switch (layer.shape().size()) {
case 2:
visualisation = new FlatLayerVisualisation(layer, FlatLayerVisualisation::Ordering::SQUARE);
break;
case 4:
visualisation = new MultiImageVisualisation(layer);
break;
default:
abort();
}
rendererData.layerVisualisations.emplace_back(visualisation); rendererData.layerVisualisations.emplace_back(visualisation);
} }
@@ -123,6 +119,20 @@ static void updateVisualisers()
glutPostRedisplay(); glutPostRedisplay();
} }
LayerVisualisation *getVisualisationForLayer(const LayerData &layer)
{
switch (layer.shape().size()) {
case 2:
return new FlatLayerVisualisation(layer, FlatLayerVisualisation::Ordering::SQUARE);
case 4:
return new MultiImageVisualisation(layer);
default:
return new DummyLayerVisualisation();
}
}
static void specialKeyFunc(int key, int, int) static void specialKeyFunc(int key, int, int)
{ {
switch (key) { switch (key) {