Add a new option to add a labels file and a means file.

This commit is contained in:
2017-10-09 13:00:26 +02:00
parent 1e901507fa
commit e70c5f22ce
5 changed files with 92 additions and 34 deletions

View File

@@ -10,14 +10,14 @@ using namespace std;
static void show_help(const char *progname, int exitcode) { static void show_help(const char *progname, int exitcode) {
cerr << "Usage: " << progname << " -m MODEL -w WEIGHTS INPUTS..." << endl cerr << "Usage: " << progname << " -m MODEL -w WEIGHTS INPUTS..." << endl
<< endl << endl
<< R"END( << R"END(Simulate the specified network on the specified inputs.
Simulate the specified network on the specified inputs.
Options: Options:
-h show this message -h show this message
-m (required) the model file to simulate -n (required) the model file to simulate
-w (required) the trained weights -w (required) the trained weights
)END" << endl; -m means file. Will be substracted from input if available.
-l labels file. Will be used to print prediction labels if available.)END" << endl;
exit(exitcode); exit(exitcode);
} }
@@ -32,10 +32,12 @@ static void check_file(const char *filename) {
Options Options::parse(const int argc, char *const argv[]) { Options Options::parse(const int argc, char *const argv[]) {
string model; string model;
string weights; string weights;
string means;
string labels;
char c; char c;
while ((c = getopt(argc, argv, "hm:w:")) != -1) { while ((c = getopt(argc, argv, "hm:w:n:l:")) != -1) {
switch (c) { switch (c) {
case 'h': case 'h':
show_help(argv[0], 0); show_help(argv[0], 0);
@@ -46,16 +48,27 @@ Options Options::parse(const int argc, char *const argv[]) {
weights = optarg; weights = optarg;
break; break;
case 'm': case 'n':
check_file(optarg); check_file(optarg);
model = optarg; model = optarg;
break; break;
case 'm':
check_file(optarg);
means = optarg;
break;
case 'l':
check_file(optarg);
labels = optarg;
break;
case '?': case '?':
show_help(argv[0], 1); show_help(argv[0], 1);
break; break;
default: default:
cerr << "Unhandled option: " << c << endl;
abort(); abort();
} }
} }
@@ -78,23 +91,36 @@ Options Options::parse(const int argc, char *const argv[]) {
show_help(argv[0], 1); show_help(argv[0], 1);
} }
return Options(move(model), move(weights), move(inputs)); return Options(move(model), move(weights), move(means), move(labels), move(inputs));
} }
Options::Options(string &&model, string &&weights, vector<string> &&inputs) noexcept: Options::Options(string &&model, string &&weights, string&& means, string&& labels, vector<string> &&inputs) noexcept:
modelPath(move(model)), modelPath(move(model)),
weightsPath(move(weights)), weightsPath(move(weights)),
inputPaths(move(inputs)) { meansPath(means),
labelsPath(labels),
inputPaths(move(inputs))
{
} }
const string &Options::model() const { const string& Options::model() const {
return modelPath; return modelPath;
} }
const string &Options::weights() const { const string& Options::weights() const {
return weightsPath; return weightsPath;
} }
const vector<string> &Options::inputs() const { const vector<string>& Options::inputs() const {
return inputPaths; return inputPaths;
} }
const string& Options::means() const
{
return meansPath;
}
const string& Options::labels() const
{
return labelsPath;
}

View File

@@ -12,17 +12,20 @@ namespace fmri {
public: public:
static Options parse(const int argc, char *const argv[]); static Options parse(const int argc, char *const argv[]);
const string &model() const; const string& model() const;
const string& weights() const;
const string& means() const;
const string& labels() const;
const string &weights() const; const vector<string>& inputs() const;
const vector<string> &inputs() const;
private: private:
const string modelPath; const string modelPath;
const string weightsPath; const string weightsPath;
const string meansPath;
const string labelsPath;
const vector<string> inputPaths; const vector<string> inputPaths;
Options(string &&, string &&, vector<string> &&) noexcept; Options(string &&, string &&, string&&, string&&, vector<string> &&) noexcept;
}; };
} }

View File

@@ -1,5 +1,5 @@
#include <cassert>
#include <iostream> #include <iostream>
#include <iterator>
#include <vector> #include <vector>
#include "Simulator.hpp" #include "Simulator.hpp"
@@ -8,7 +8,7 @@ using namespace caffe;
using namespace std; using namespace std;
using namespace fmri; using namespace fmri;
Simulator::Simulator(const string& model_file, const string& weights_file) : Simulator::Simulator(const string& model_file, const string& weights_file, const string& means_file) :
net(model_file, TEST) net(model_file, TEST)
{ {
net.CopyTrainedLayersFrom(weights_file); net.CopyTrainedLayersFrom(weights_file);
@@ -21,16 +21,18 @@ Simulator::Simulator(const string& model_file, const string& weights_file) :
input_geometry.height, input_geometry.width); input_geometry.height, input_geometry.width);
/* Forward dimension change to all layers. */ /* Forward dimension change to all layers. */
net.Reshape(); net.Reshape();
if (means_file != "") {
means = processMeans(means_file);
}
} }
void Simulator::simulate(const string& image_file) vector<Simulator::DType> Simulator::simulate(const string& image_file)
{ {
cv::Mat im = cv::imread(image_file, -1); cv::Mat im = cv::imread(image_file, -1);
if (im.empty()) { assert(!im.empty());
cerr << "Unable to read " << image_file << endl;
return;
}
auto input = preprocess(im); auto input = preprocess(im);
auto channels = getWrappedInputLayer(); auto channels = getWrappedInputLayer();
@@ -44,10 +46,7 @@ void Simulator::simulate(const string& image_file)
const DType *end = begin + output_layer->channels(); const DType *end = begin + output_layer->channels();
vector<DType> result(begin, end); vector<DType> result(begin, end);
// TODO: visualize, rather than just print. return result;
for (auto v : result) {
cout << v << endl;
}
} }
vector<cv::Mat> Simulator::getWrappedInputLayer() vector<cv::Mat> Simulator::getWrappedInputLayer()
@@ -111,8 +110,36 @@ cv::Mat Simulator::preprocess(cv::Mat original) const
cv::Mat sample_float; cv::Mat sample_float;
resized.convertTo(sample_float, num_channels == 3 ? CV_32FC3 : CV_32FC1); resized.convertTo(sample_float, num_channels == 3 ? CV_32FC3 : CV_32FC1);
// TODO: substract means. if (means.empty()) {
// Don't know if necessary yet.
return sample_float; return sample_float;
}
cv::Mat normalized;
cv::subtract(sample_float, means, normalized);
return normalized;
}
cv::Mat Simulator::processMeans(const string &means_file) const
{
BlobProto proto;
ReadProtoFromBinaryFileOrDie(means_file, &proto);
Blob<DType> mean_blob;
mean_blob.FromProto(proto);
assert(mean_blob.channels() == num_channels);
vector<cv::Mat> channels;
float* data = mean_blob.mutable_cpu_data();
for (unsigned int i = 0; i < num_channels; ++i) {
channels.emplace_back(mean_blob.height(), mean_blob.width(), CV_32FC1, data);
data += mean_blob.height() * mean_blob.width();
}
cv::Mat mean;
cv::merge(channels, mean);
return cv::Mat(input_geometry, mean.type(), cv::mean(mean));
} }

View File

@@ -17,16 +17,18 @@ namespace fmri {
public: public:
typedef float DType; typedef float DType;
Simulator(const string &model_file, const string &weights_file); Simulator(const string &model_file, const string &weights_file, const string &means_file = "");
void simulate(const string &input_file); vector<DType> simulate(const string &input_file);
private: private:
caffe::Net<DType> net; caffe::Net<DType> net;
cv::Size input_geometry; cv::Size input_geometry;
cv::Mat means;
unsigned int num_channels; unsigned int num_channels;
vector<cv::Mat> getWrappedInputLayer(); vector<cv::Mat> getWrappedInputLayer();
cv::Mat preprocess(cv::Mat original) const; cv::Mat preprocess(cv::Mat original) const;
cv::Mat processMeans(const string &means_file) const;
}; };
} }

View File

@@ -10,7 +10,7 @@ int main(int argc, char *const argv[]) {
Options options = Options::parse(argc, argv); Options options = Options::parse(argc, argv);
Simulator simulator(options.model(), options.weights()); Simulator simulator(options.model(), options.weights(), options.means());
for (const auto &image : options.inputs()) { for (const auto &image : options.inputs()) {
simulator.simulate(image); simulator.simulate(image);