Remove redundant code.

This commit is contained in:
2018-03-13 14:28:00 +01:00
parent e21c7b33e8
commit 93d4cb4df0
6 changed files with 69 additions and 68 deletions

View File

@@ -20,27 +20,8 @@ MultiImageVisualisation::MultiImageVisualisation(const fmri::LayerData &layer)
texture = loadTexture(layer.data(), width, channels * height, channels);
initNodePositions<Ordering::SQUARE>(channels, 3);
vertexBuffer = std::make_unique<float[]>(channels * BASE_VERTICES.size());
texCoordBuffer = std::make_unique<float[]>(channels * 2u * BASE_VERTICES.size() / 3);
auto v = 0;
for (auto i : Range(channels)) {
const auto& nodePos = &nodePositions_[3 * i];
for (auto j : Range(BASE_VERTICES.size())) {
vertexBuffer[v++] = nodePos[j % 3] + BASE_VERTICES[j];
}
const float textureCoords[] = {
1, (i + 1) / (float) channels,
1, i / (float) channels,
0, i / (float) channels,
0, (i + 1) / (float) channels,
};
memcpy(texCoordBuffer.get() + 8 * i, textureCoords, sizeof(textureCoords));
}
vertexBuffer = getVertices(nodePositions_);
texCoordBuffer = getTexCoords(channels);
}
void MultiImageVisualisation::render()
@@ -50,10 +31,47 @@ void MultiImageVisualisation::render()
glEnable(GL_TEXTURE_2D);
glTexEnvf(GL_TEXTURE_ENV, GL_TEXTURE_ENV_MODE, GL_REPLACE);
texture.bind(GL_TEXTURE_2D);
glTexCoordPointer(2, GL_FLOAT, 0, texCoordBuffer.get());
glVertexPointer(3, GL_FLOAT, 0, vertexBuffer.get());
glDrawArrays(GL_QUADS, 0, nodePositions_.size() / 3 * 4);
glTexCoordPointer(2, GL_FLOAT, 0, texCoordBuffer.data());
glVertexPointer(3, GL_FLOAT, 0, vertexBuffer.data());
glDrawArrays(GL_QUADS, 0, vertexBuffer.size() / 3);
glDisable(GL_TEXTURE_2D);
glDisableClientState(GL_VERTEX_ARRAY);
glDisableClientState(GL_TEXTURE_COORD_ARRAY);
}
vector<float> MultiImageVisualisation::getVertices(const std::vector<float> &nodePositions, float scaling)
{
std::vector<float> vertices;
vertices.reserve(nodePositions.size() * BASE_VERTICES.size() / 3);
for (auto i = 0u; i < nodePositions.size(); i += 3) {
auto pos = &nodePositions[i];
for (auto j = 0u; j < BASE_VERTICES.size(); ++j) {
vertices.push_back(BASE_VERTICES[j] * scaling + pos[j % 3]);
}
}
return vertices;
}
std::vector<float> MultiImageVisualisation::getTexCoords(int n)
{
std::vector<float> coords;
coords.reserve(8 * n);
const float channels = n;
for (int i = 0; i < n; ++i) {
std::array<float, 8> textureCoords = {
1, (i + 1) / channels,
1, i / channels,
0, i / channels,
0, (i + 1) / channels,
};
for (auto coord : textureCoords) {
coords.push_back(coord);
}
}
return coords;
}

View File

@@ -22,9 +22,12 @@ namespace fmri
void render() override;
static vector<float> getVertices(const std::vector<float> &nodePositions, float scaling = 1);
static std::vector<float> getTexCoords(int n);
private:
Texture texture;
std::unique_ptr<float[]> vertexBuffer;
std::unique_ptr<float[]> texCoordBuffer;
std::vector<float> vertexBuffer;
std::vector<float> texCoordBuffer;
};
}

View File

@@ -14,31 +14,14 @@ PoolingLayerAnimation::PoolingLayerAnimation(const LayerData &prevData, const La
const std::vector<float> &curPositions, float xDist) :
original(loadTextureForData(prevData)),
downSampled(loadTextureForData(curData)),
startingPositions(computePositions(prevPositions)),
deltas(startingPositions.size())
startingPositions(MultiImageVisualisation::getVertices(prevPositions)),
deltas(startingPositions.size()),
textureCoordinates(MultiImageVisualisation::getTexCoords(prevPositions.size() / 3))
{
CHECK_EQ(prevPositions.size(), curPositions.size()) << "Layers should be same size. Caffe error?";
const float channels = curData.shape()[1];
textureCoordinates.reserve(curPositions.size() / 3 * 4);
const auto downScaling = sqrt(
static_cast<float>(curData.shape()[2] * curData.shape()[3]) / (prevData.shape()[2] * prevData.shape()[3]));
for (auto i : Range(prevPositions.size() / 3)) {
const array<float, 8> nodeTexCoords = {
1, (i + 1) / channels,
1, i / channels,
0, i / channels,
0, (i + 1) / channels,
};
for (auto coord : nodeTexCoords) {
textureCoordinates.push_back(coord);
}
}
const auto targetPositions = computePositions(curPositions, downScaling);
const auto targetPositions = MultiImageVisualisation::getVertices(curPositions, downScaling);
caffe::caffe_sub(targetPositions.size(), targetPositions.data(), startingPositions.data(), deltas.data());
for (auto i = 0u; i < deltas.size(); i+=3) {
@@ -73,18 +56,3 @@ Texture PoolingLayerAnimation::loadTextureForData(const LayerData &data)
auto channels = data.shape()[1], width = data.shape()[2], height = data.shape()[3];
return loadTexture(data.data(), width, height * channels, channels);
}
vector<float> PoolingLayerAnimation::computePositions(const vector<float> &nodePositions, float scaling)
{
vector<float> positions;
positions.reserve(4 * nodePositions.size());
for (auto i : Range(nodePositions.size() / 3)) {
const float *pos = &nodePositions[3 * i];
for (auto j : Range(MultiImageVisualisation::BASE_VERTICES.size())) {
positions.push_back(pos[j % 3] + MultiImageVisualisation::BASE_VERTICES[j] * scaling);
}
}
return positions;
}

View File

@@ -21,6 +21,5 @@ namespace fmri
std::vector<float> textureCoordinates;
static Texture loadTextureForData(const LayerData& data);
static std::vector<float> computePositions(const std::vector<float> &nodePositions, float scaling = 1);
};
}

View File

@@ -207,4 +207,21 @@ namespace fmri
return indices;
}
/**
* Fix non-normal floating point values in a range.
*
* @tparam It
* @param first Start of range iterator
* @param last Past the end of range iterator
* @param normalValue Value to assign to non-normal values. Default 1.
*/
template<class It>
inline void normalize(It first, It last, typename std::iterator_traits<It>::value_type normalValue = 1)
{
for (; first != last; ++first) {
if (!std::isnormal(*first)) {
*first = normalValue;
}
}
}
}

View File

@@ -184,11 +184,7 @@ static Animation *getNormalizingAnimation(const fmri::LayerData &prevState, cons
caffe::caffe_div(scaling.size(), prevState.data(), curState.data(), scaling.data());
// Fix divisions by zero. For those cases, pick 1 since it doesn't matter anyway.
for (auto &s : scaling) {
if (!isnormal(s)) {
s = 1;
}
}
normalize(scaling.begin(), scaling.end());
if (prevState.shape().size() == 2) {
EntryList entries;