From e70c5f22ce0648d2f795a4b633c09d1d7595b3ac Mon Sep 17 00:00:00 2001 From: Bert Peters Date: Mon, 9 Oct 2017 13:00:26 +0200 Subject: [PATCH] Add a new option to add a labels file and a means file. --- src/Options.cpp | 50 +++++++++++++++++++++++++++++++----------- src/Options.hpp | 13 ++++++----- src/Simulator.cpp | 55 +++++++++++++++++++++++++++++++++++------------ src/Simulator.hpp | 6 ++++-- src/main.cpp | 2 +- 5 files changed, 92 insertions(+), 34 deletions(-) diff --git a/src/Options.cpp b/src/Options.cpp index 5473ad3..1b17b1c 100644 --- a/src/Options.cpp +++ b/src/Options.cpp @@ -10,14 +10,14 @@ using namespace std; static void show_help(const char *progname, int exitcode) { cerr << "Usage: " << progname << " -m MODEL -w WEIGHTS INPUTS..." << endl << endl - << R"END( -Simulate the specified network on the specified inputs. + << R"END(Simulate the specified network on the specified inputs. Options: -h show this message - -m (required) the model file to simulate + -n (required) the model file to simulate -w (required) the trained weights -)END" << endl; + -m means file. Will be substracted from input if available. + -l labels file. Will be used to print prediction labels if available.)END" << endl; exit(exitcode); } @@ -32,10 +32,12 @@ static void check_file(const char *filename) { Options Options::parse(const int argc, char *const argv[]) { string model; string weights; + string means; + string labels; char c; - while ((c = getopt(argc, argv, "hm:w:")) != -1) { + while ((c = getopt(argc, argv, "hm:w:n:l:")) != -1) { switch (c) { case 'h': show_help(argv[0], 0); @@ -46,16 +48,27 @@ Options Options::parse(const int argc, char *const argv[]) { weights = optarg; break; - case 'm': + case 'n': check_file(optarg); model = optarg; break; + case 'm': + check_file(optarg); + means = optarg; + break; + + case 'l': + check_file(optarg); + labels = optarg; + break; + case '?': show_help(argv[0], 1); break; default: + cerr << "Unhandled option: " << c << endl; abort(); } } @@ -78,23 +91,36 @@ Options Options::parse(const int argc, char *const argv[]) { show_help(argv[0], 1); } - return Options(move(model), move(weights), move(inputs)); + return Options(move(model), move(weights), move(means), move(labels), move(inputs)); } -Options::Options(string &&model, string &&weights, vector &&inputs) noexcept: +Options::Options(string &&model, string &&weights, string&& means, string&& labels, vector &&inputs) noexcept: modelPath(move(model)), weightsPath(move(weights)), - inputPaths(move(inputs)) { + meansPath(means), + labelsPath(labels), + inputPaths(move(inputs)) +{ } -const string &Options::model() const { +const string& Options::model() const { return modelPath; } -const string &Options::weights() const { +const string& Options::weights() const { return weightsPath; } -const vector &Options::inputs() const { +const vector& Options::inputs() const { return inputPaths; } + +const string& Options::means() const +{ + return meansPath; +} + +const string& Options::labels() const +{ + return labelsPath; +} diff --git a/src/Options.hpp b/src/Options.hpp index 0c2b2ac..b5d672b 100644 --- a/src/Options.hpp +++ b/src/Options.hpp @@ -12,17 +12,20 @@ namespace fmri { public: static Options parse(const int argc, char *const argv[]); - const string &model() const; + const string& model() const; + const string& weights() const; + const string& means() const; + const string& labels() const; - const string &weights() const; - - const vector &inputs() const; + const vector& inputs() const; private: const string modelPath; const string weightsPath; + const string meansPath; + const string labelsPath; const vector inputPaths; - Options(string &&, string &&, vector &&) noexcept; + Options(string &&, string &&, string&&, string&&, vector &&) noexcept; }; } diff --git a/src/Simulator.cpp b/src/Simulator.cpp index cc7a3cf..ac72505 100644 --- a/src/Simulator.cpp +++ b/src/Simulator.cpp @@ -1,5 +1,5 @@ +#include #include -#include #include #include "Simulator.hpp" @@ -8,7 +8,7 @@ using namespace caffe; using namespace std; using namespace fmri; -Simulator::Simulator(const string& model_file, const string& weights_file) : +Simulator::Simulator(const string& model_file, const string& weights_file, const string& means_file) : net(model_file, TEST) { net.CopyTrainedLayersFrom(weights_file); @@ -21,16 +21,18 @@ Simulator::Simulator(const string& model_file, const string& weights_file) : input_geometry.height, input_geometry.width); /* Forward dimension change to all layers. */ net.Reshape(); + + if (means_file != "") { + means = processMeans(means_file); + } + } -void Simulator::simulate(const string& image_file) +vector Simulator::simulate(const string& image_file) { cv::Mat im = cv::imread(image_file, -1); - if (im.empty()) { - cerr << "Unable to read " << image_file << endl; - return; - } + assert(!im.empty()); auto input = preprocess(im); auto channels = getWrappedInputLayer(); @@ -44,10 +46,7 @@ void Simulator::simulate(const string& image_file) const DType *end = begin + output_layer->channels(); vector result(begin, end); - // TODO: visualize, rather than just print. - for (auto v : result) { - cout << v << endl; - } + return result; } vector Simulator::getWrappedInputLayer() @@ -111,8 +110,36 @@ cv::Mat Simulator::preprocess(cv::Mat original) const cv::Mat sample_float; resized.convertTo(sample_float, num_channels == 3 ? CV_32FC3 : CV_32FC1); - // TODO: substract means. - // Don't know if necessary yet. + if (means.empty()) { + return sample_float; + } + + cv::Mat normalized; + cv::subtract(sample_float, means, normalized); + + return normalized; - return sample_float; +} + +cv::Mat Simulator::processMeans(const string &means_file) const +{ + BlobProto proto; + ReadProtoFromBinaryFileOrDie(means_file, &proto); + + Blob mean_blob; + mean_blob.FromProto(proto); + + assert(mean_blob.channels() == num_channels); + + vector channels; + float* data = mean_blob.mutable_cpu_data(); + for (unsigned int i = 0; i < num_channels; ++i) { + channels.emplace_back(mean_blob.height(), mean_blob.width(), CV_32FC1, data); + data += mean_blob.height() * mean_blob.width(); + } + + cv::Mat mean; + cv::merge(channels, mean); + + return cv::Mat(input_geometry, mean.type(), cv::mean(mean)); } diff --git a/src/Simulator.hpp b/src/Simulator.hpp index bfb4838..d969fb8 100644 --- a/src/Simulator.hpp +++ b/src/Simulator.hpp @@ -17,16 +17,18 @@ namespace fmri { public: typedef float DType; - Simulator(const string &model_file, const string &weights_file); + Simulator(const string &model_file, const string &weights_file, const string &means_file = ""); - void simulate(const string &input_file); + vector simulate(const string &input_file); private: caffe::Net net; cv::Size input_geometry; + cv::Mat means; unsigned int num_channels; vector getWrappedInputLayer(); cv::Mat preprocess(cv::Mat original) const; + cv::Mat processMeans(const string &means_file) const; }; } diff --git a/src/main.cpp b/src/main.cpp index 6682384..3443802 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -10,7 +10,7 @@ int main(int argc, char *const argv[]) { Options options = Options::parse(argc, argv); - Simulator simulator(options.model(), options.weights()); + Simulator simulator(options.model(), options.weights(), options.means()); for (const auto &image : options.inputs()) { simulator.simulate(image);