Add a new option to add a labels file and a means file.
This commit is contained in:
@@ -10,14 +10,14 @@ using namespace std;
|
|||||||
static void show_help(const char *progname, int exitcode) {
|
static void show_help(const char *progname, int exitcode) {
|
||||||
cerr << "Usage: " << progname << " -m MODEL -w WEIGHTS INPUTS..." << endl
|
cerr << "Usage: " << progname << " -m MODEL -w WEIGHTS INPUTS..." << endl
|
||||||
<< endl
|
<< endl
|
||||||
<< R"END(
|
<< R"END(Simulate the specified network on the specified inputs.
|
||||||
Simulate the specified network on the specified inputs.
|
|
||||||
|
|
||||||
Options:
|
Options:
|
||||||
-h show this message
|
-h show this message
|
||||||
-m (required) the model file to simulate
|
-n (required) the model file to simulate
|
||||||
-w (required) the trained weights
|
-w (required) the trained weights
|
||||||
)END" << endl;
|
-m means file. Will be substracted from input if available.
|
||||||
|
-l labels file. Will be used to print prediction labels if available.)END" << endl;
|
||||||
|
|
||||||
exit(exitcode);
|
exit(exitcode);
|
||||||
}
|
}
|
||||||
@@ -32,10 +32,12 @@ static void check_file(const char *filename) {
|
|||||||
Options Options::parse(const int argc, char *const argv[]) {
|
Options Options::parse(const int argc, char *const argv[]) {
|
||||||
string model;
|
string model;
|
||||||
string weights;
|
string weights;
|
||||||
|
string means;
|
||||||
|
string labels;
|
||||||
|
|
||||||
char c;
|
char c;
|
||||||
|
|
||||||
while ((c = getopt(argc, argv, "hm:w:")) != -1) {
|
while ((c = getopt(argc, argv, "hm:w:n:l:")) != -1) {
|
||||||
switch (c) {
|
switch (c) {
|
||||||
case 'h':
|
case 'h':
|
||||||
show_help(argv[0], 0);
|
show_help(argv[0], 0);
|
||||||
@@ -46,16 +48,27 @@ Options Options::parse(const int argc, char *const argv[]) {
|
|||||||
weights = optarg;
|
weights = optarg;
|
||||||
break;
|
break;
|
||||||
|
|
||||||
case 'm':
|
case 'n':
|
||||||
check_file(optarg);
|
check_file(optarg);
|
||||||
model = optarg;
|
model = optarg;
|
||||||
break;
|
break;
|
||||||
|
|
||||||
|
case 'm':
|
||||||
|
check_file(optarg);
|
||||||
|
means = optarg;
|
||||||
|
break;
|
||||||
|
|
||||||
|
case 'l':
|
||||||
|
check_file(optarg);
|
||||||
|
labels = optarg;
|
||||||
|
break;
|
||||||
|
|
||||||
case '?':
|
case '?':
|
||||||
show_help(argv[0], 1);
|
show_help(argv[0], 1);
|
||||||
break;
|
break;
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
cerr << "Unhandled option: " << c << endl;
|
||||||
abort();
|
abort();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -78,23 +91,36 @@ Options Options::parse(const int argc, char *const argv[]) {
|
|||||||
show_help(argv[0], 1);
|
show_help(argv[0], 1);
|
||||||
}
|
}
|
||||||
|
|
||||||
return Options(move(model), move(weights), move(inputs));
|
return Options(move(model), move(weights), move(means), move(labels), move(inputs));
|
||||||
}
|
}
|
||||||
|
|
||||||
Options::Options(string &&model, string &&weights, vector<string> &&inputs) noexcept:
|
Options::Options(string &&model, string &&weights, string&& means, string&& labels, vector<string> &&inputs) noexcept:
|
||||||
modelPath(move(model)),
|
modelPath(move(model)),
|
||||||
weightsPath(move(weights)),
|
weightsPath(move(weights)),
|
||||||
inputPaths(move(inputs)) {
|
meansPath(means),
|
||||||
|
labelsPath(labels),
|
||||||
|
inputPaths(move(inputs))
|
||||||
|
{
|
||||||
}
|
}
|
||||||
|
|
||||||
const string &Options::model() const {
|
const string& Options::model() const {
|
||||||
return modelPath;
|
return modelPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
const string &Options::weights() const {
|
const string& Options::weights() const {
|
||||||
return weightsPath;
|
return weightsPath;
|
||||||
}
|
}
|
||||||
|
|
||||||
const vector<string> &Options::inputs() const {
|
const vector<string>& Options::inputs() const {
|
||||||
return inputPaths;
|
return inputPaths;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const string& Options::means() const
|
||||||
|
{
|
||||||
|
return meansPath;
|
||||||
|
}
|
||||||
|
|
||||||
|
const string& Options::labels() const
|
||||||
|
{
|
||||||
|
return labelsPath;
|
||||||
|
}
|
||||||
|
|||||||
@@ -12,17 +12,20 @@ namespace fmri {
|
|||||||
public:
|
public:
|
||||||
static Options parse(const int argc, char *const argv[]);
|
static Options parse(const int argc, char *const argv[]);
|
||||||
|
|
||||||
const string &model() const;
|
const string& model() const;
|
||||||
|
const string& weights() const;
|
||||||
|
const string& means() const;
|
||||||
|
const string& labels() const;
|
||||||
|
|
||||||
const string &weights() const;
|
const vector<string>& inputs() const;
|
||||||
|
|
||||||
const vector<string> &inputs() const;
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
const string modelPath;
|
const string modelPath;
|
||||||
const string weightsPath;
|
const string weightsPath;
|
||||||
|
const string meansPath;
|
||||||
|
const string labelsPath;
|
||||||
const vector<string> inputPaths;
|
const vector<string> inputPaths;
|
||||||
|
|
||||||
Options(string &&, string &&, vector<string> &&) noexcept;
|
Options(string &&, string &&, string&&, string&&, vector<string> &&) noexcept;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
|
|||||||
|
#include <cassert>
|
||||||
#include <iostream>
|
#include <iostream>
|
||||||
#include <iterator>
|
|
||||||
#include <vector>
|
#include <vector>
|
||||||
|
|
||||||
#include "Simulator.hpp"
|
#include "Simulator.hpp"
|
||||||
@@ -8,7 +8,7 @@ using namespace caffe;
|
|||||||
using namespace std;
|
using namespace std;
|
||||||
using namespace fmri;
|
using namespace fmri;
|
||||||
|
|
||||||
Simulator::Simulator(const string& model_file, const string& weights_file) :
|
Simulator::Simulator(const string& model_file, const string& weights_file, const string& means_file) :
|
||||||
net(model_file, TEST)
|
net(model_file, TEST)
|
||||||
{
|
{
|
||||||
net.CopyTrainedLayersFrom(weights_file);
|
net.CopyTrainedLayersFrom(weights_file);
|
||||||
@@ -21,16 +21,18 @@ Simulator::Simulator(const string& model_file, const string& weights_file) :
|
|||||||
input_geometry.height, input_geometry.width);
|
input_geometry.height, input_geometry.width);
|
||||||
/* Forward dimension change to all layers. */
|
/* Forward dimension change to all layers. */
|
||||||
net.Reshape();
|
net.Reshape();
|
||||||
|
|
||||||
|
if (means_file != "") {
|
||||||
|
means = processMeans(means_file);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
void Simulator::simulate(const string& image_file)
|
vector<Simulator::DType> Simulator::simulate(const string& image_file)
|
||||||
{
|
{
|
||||||
cv::Mat im = cv::imread(image_file, -1);
|
cv::Mat im = cv::imread(image_file, -1);
|
||||||
|
|
||||||
if (im.empty()) {
|
assert(!im.empty());
|
||||||
cerr << "Unable to read " << image_file << endl;
|
|
||||||
return;
|
|
||||||
}
|
|
||||||
|
|
||||||
auto input = preprocess(im);
|
auto input = preprocess(im);
|
||||||
auto channels = getWrappedInputLayer();
|
auto channels = getWrappedInputLayer();
|
||||||
@@ -44,10 +46,7 @@ void Simulator::simulate(const string& image_file)
|
|||||||
const DType *end = begin + output_layer->channels();
|
const DType *end = begin + output_layer->channels();
|
||||||
vector<DType> result(begin, end);
|
vector<DType> result(begin, end);
|
||||||
|
|
||||||
// TODO: visualize, rather than just print.
|
return result;
|
||||||
for (auto v : result) {
|
|
||||||
cout << v << endl;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
vector<cv::Mat> Simulator::getWrappedInputLayer()
|
vector<cv::Mat> Simulator::getWrappedInputLayer()
|
||||||
@@ -111,8 +110,36 @@ cv::Mat Simulator::preprocess(cv::Mat original) const
|
|||||||
cv::Mat sample_float;
|
cv::Mat sample_float;
|
||||||
resized.convertTo(sample_float, num_channels == 3 ? CV_32FC3 : CV_32FC1);
|
resized.convertTo(sample_float, num_channels == 3 ? CV_32FC3 : CV_32FC1);
|
||||||
|
|
||||||
// TODO: substract means.
|
if (means.empty()) {
|
||||||
// Don't know if necessary yet.
|
|
||||||
|
|
||||||
return sample_float;
|
return sample_float;
|
||||||
|
}
|
||||||
|
|
||||||
|
cv::Mat normalized;
|
||||||
|
cv::subtract(sample_float, means, normalized);
|
||||||
|
|
||||||
|
return normalized;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
cv::Mat Simulator::processMeans(const string &means_file) const
|
||||||
|
{
|
||||||
|
BlobProto proto;
|
||||||
|
ReadProtoFromBinaryFileOrDie(means_file, &proto);
|
||||||
|
|
||||||
|
Blob<DType> mean_blob;
|
||||||
|
mean_blob.FromProto(proto);
|
||||||
|
|
||||||
|
assert(mean_blob.channels() == num_channels);
|
||||||
|
|
||||||
|
vector<cv::Mat> channels;
|
||||||
|
float* data = mean_blob.mutable_cpu_data();
|
||||||
|
for (unsigned int i = 0; i < num_channels; ++i) {
|
||||||
|
channels.emplace_back(mean_blob.height(), mean_blob.width(), CV_32FC1, data);
|
||||||
|
data += mean_blob.height() * mean_blob.width();
|
||||||
|
}
|
||||||
|
|
||||||
|
cv::Mat mean;
|
||||||
|
cv::merge(channels, mean);
|
||||||
|
|
||||||
|
return cv::Mat(input_geometry, mean.type(), cv::mean(mean));
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,16 +17,18 @@ namespace fmri {
|
|||||||
public:
|
public:
|
||||||
typedef float DType;
|
typedef float DType;
|
||||||
|
|
||||||
Simulator(const string &model_file, const string &weights_file);
|
Simulator(const string &model_file, const string &weights_file, const string &means_file = "");
|
||||||
|
|
||||||
void simulate(const string &input_file);
|
vector<DType> simulate(const string &input_file);
|
||||||
|
|
||||||
private:
|
private:
|
||||||
caffe::Net<DType> net;
|
caffe::Net<DType> net;
|
||||||
cv::Size input_geometry;
|
cv::Size input_geometry;
|
||||||
|
cv::Mat means;
|
||||||
unsigned int num_channels;
|
unsigned int num_channels;
|
||||||
|
|
||||||
vector<cv::Mat> getWrappedInputLayer();
|
vector<cv::Mat> getWrappedInputLayer();
|
||||||
cv::Mat preprocess(cv::Mat original) const;
|
cv::Mat preprocess(cv::Mat original) const;
|
||||||
|
cv::Mat processMeans(const string &means_file) const;
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ int main(int argc, char *const argv[]) {
|
|||||||
|
|
||||||
Options options = Options::parse(argc, argv);
|
Options options = Options::parse(argc, argv);
|
||||||
|
|
||||||
Simulator simulator(options.model(), options.weights());
|
Simulator simulator(options.model(), options.weights(), options.means());
|
||||||
|
|
||||||
for (const auto &image : options.inputs()) {
|
for (const auto &image : options.inputs()) {
|
||||||
simulator.simulate(image);
|
simulator.simulate(image);
|
||||||
|
|||||||
Reference in New Issue
Block a user