From cf08dd476167aa60ba1df970d820196c34cee7a2 Mon Sep 17 00:00:00 2001 From: Bert Peters Date: Sat, 7 Oct 2017 21:35:59 +0200 Subject: [PATCH] Actually run the network. --- src/Simulator.cpp | 113 ++++++++++++++++++++++++++++++++++++++++++++-- src/Simulator.hpp | 10 ++++ src/main.cpp | 2 + 3 files changed, 120 insertions(+), 5 deletions(-) diff --git a/src/Simulator.cpp b/src/Simulator.cpp index 4919ea2..cc7a3cf 100644 --- a/src/Simulator.cpp +++ b/src/Simulator.cpp @@ -1,15 +1,118 @@ #include +#include +#include + #include "Simulator.hpp" using namespace caffe; using namespace std; using namespace fmri; -Simulator::Simulator(const string &model_file, const string &weights_file) : - net(model_file, TEST) { - net.CopyTrainedLayersFrom(weights_file); +Simulator::Simulator(const string& model_file, const string& weights_file) : + net(model_file, TEST) +{ + net.CopyTrainedLayersFrom(weights_file); + + Blob* input_layer = net.input_blobs()[0]; + input_geometry = cv::Size(input_layer->width(), input_layer->height()); + num_channels = input_layer->channels(); + + input_layer->Reshape(1, num_channels, + input_geometry.height, input_geometry.width); + /* Forward dimension change to all layers. */ + net.Reshape(); } -void Simulator::simulate(const string &image_file) { - cerr << "This is not implemented yet." << endl; +void Simulator::simulate(const string& image_file) +{ + cv::Mat im = cv::imread(image_file, -1); + + if (im.empty()) { + cerr << "Unable to read " << image_file << endl; + return; + } + + auto input = preprocess(im); + auto channels = getWrappedInputLayer(); + + cv::split(input, channels); + + net.Forward(); + + Blob *output_layer = net.output_blobs()[0]; + const DType *begin = output_layer->cpu_data(); + const DType *end = begin + output_layer->channels(); + vector result(begin, end); + + // TODO: visualize, rather than just print. + for (auto v : result) { + cout << v << endl; + } +} + +vector Simulator::getWrappedInputLayer() +{ + vector channels; + Blob* input_layer = net.input_blobs()[0]; + + const int width = input_geometry.width; + const int height = input_geometry.height; + + DType* input_data = input_layer->mutable_cpu_data(); + for (unsigned int i = 0; i < num_channels; i++) { + channels.emplace_back(height, width, CV_32FC1, input_data); + input_data += width * height; + } + + return channels; +} + +static cv::Mat fix_channels(const int num_channels, cv::Mat original) { + if (num_channels == original.channels()) { + return original; + } + + cv::Mat converted; + + if (num_channels == 1 && original.channels() == 3) { + cv::cvtColor(original, converted, cv::COLOR_BGR2GRAY); + } else if (num_channels == 1 && original.channels() == 4) { + cv::cvtColor(original, converted, cv::COLOR_BGRA2GRAY); + } else if (num_channels == 3 && original.channels() == 1) { + cv::cvtColor(original, converted, cv::COLOR_GRAY2BGR); + } else if (num_channels == 3 && original.channels() == 4) { + cv::cvtColor(original, converted, cv::COLOR_BGRA2BGR); + } else { + // Don't know how to convert. + abort(); + } + + return converted; +} + +static cv::Mat resize(const cv::Size& targetSize, cv::Mat original) +{ + if (targetSize != original.size()) { + cv::Mat resized; + cv::resize(original, resized, targetSize); + + return resized; + } + + return original; +} + +cv::Mat Simulator::preprocess(cv::Mat original) const +{ + auto converted = fix_channels(num_channels, original); + + auto resized = resize(input_geometry, converted); + + cv::Mat sample_float; + resized.convertTo(sample_float, num_channels == 3 ? CV_32FC3 : CV_32FC1); + + // TODO: substract means. + // Don't know if necessary yet. + + return sample_float; } diff --git a/src/Simulator.hpp b/src/Simulator.hpp index 262b450..bfb4838 100644 --- a/src/Simulator.hpp +++ b/src/Simulator.hpp @@ -2,11 +2,16 @@ #include #include +#include #include +#include +#include +#include namespace fmri { using std::string; + using std::vector; class Simulator { public: @@ -18,5 +23,10 @@ namespace fmri { private: caffe::Net net; + cv::Size input_geometry; + unsigned int num_channels; + + vector getWrappedInputLayer(); + cv::Mat preprocess(cv::Mat original) const; }; } diff --git a/src/main.cpp b/src/main.cpp index 15487d8..6682384 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -16,5 +16,7 @@ int main(int argc, char *const argv[]) { simulator.simulate(image); } + ::google::ShutdownGoogleLogging(); + return 0; }