Rewrite the deinplace tool in C++.
More consistency for a better build process.
This commit is contained in:
@@ -37,3 +37,15 @@ target_include_directories(fmri PUBLIC
|
|||||||
|
|
||||||
# Allow the package to be installed
|
# Allow the package to be installed
|
||||||
install(TARGETS fmri DESTINATION bin)
|
install(TARGETS fmri DESTINATION bin)
|
||||||
|
|
||||||
|
# Build instructions for the deinplace tool
|
||||||
|
find_package(Protobuf REQUIRED)
|
||||||
|
protobuf_generate_cpp(CAFFE_PROTO_CPP CAFFE_PROTO_HEADERS "src/proto/caffe.proto")
|
||||||
|
add_executable(deinplace
|
||||||
|
"src/tools/deinplace.cpp"
|
||||||
|
${CAFFE_PROTO_CPP}
|
||||||
|
${CAFFE_PROTO_HEADERS}
|
||||||
|
)
|
||||||
|
target_compile_options(deinplace PRIVATE "-Wall" "-Wextra" "-pedantic")
|
||||||
|
target_link_libraries(deinplace protobuf::libprotobuf)
|
||||||
|
target_include_directories(deinplace PRIVATE "${CMAKE_CURRENT_BINARY_DIR}")
|
||||||
|
|||||||
@@ -18,12 +18,12 @@ layer {
|
|||||||
bottom: "data"
|
bottom: "data"
|
||||||
top: "conv1"
|
top: "conv1"
|
||||||
param {
|
param {
|
||||||
lr_mult: 1.0
|
lr_mult: 1
|
||||||
decay_mult: 1.0
|
decay_mult: 1
|
||||||
}
|
}
|
||||||
param {
|
param {
|
||||||
lr_mult: 2.0
|
lr_mult: 2
|
||||||
decay_mult: 0.0
|
decay_mult: 0
|
||||||
}
|
}
|
||||||
convolution_param {
|
convolution_param {
|
||||||
num_output: 96
|
num_output: 96
|
||||||
@@ -65,12 +65,12 @@ layer {
|
|||||||
bottom: "pool1"
|
bottom: "pool1"
|
||||||
top: "conv2"
|
top: "conv2"
|
||||||
param {
|
param {
|
||||||
lr_mult: 1.0
|
lr_mult: 1
|
||||||
decay_mult: 1.0
|
decay_mult: 1
|
||||||
}
|
}
|
||||||
param {
|
param {
|
||||||
lr_mult: 2.0
|
lr_mult: 2
|
||||||
decay_mult: 0.0
|
decay_mult: 0
|
||||||
}
|
}
|
||||||
convolution_param {
|
convolution_param {
|
||||||
num_output: 256
|
num_output: 256
|
||||||
@@ -113,12 +113,12 @@ layer {
|
|||||||
bottom: "pool2"
|
bottom: "pool2"
|
||||||
top: "conv3"
|
top: "conv3"
|
||||||
param {
|
param {
|
||||||
lr_mult: 1.0
|
lr_mult: 1
|
||||||
decay_mult: 1.0
|
decay_mult: 1
|
||||||
}
|
}
|
||||||
param {
|
param {
|
||||||
lr_mult: 2.0
|
lr_mult: 2
|
||||||
decay_mult: 0.0
|
decay_mult: 0
|
||||||
}
|
}
|
||||||
convolution_param {
|
convolution_param {
|
||||||
num_output: 384
|
num_output: 384
|
||||||
@@ -138,12 +138,12 @@ layer {
|
|||||||
bottom: "relu3"
|
bottom: "relu3"
|
||||||
top: "conv4"
|
top: "conv4"
|
||||||
param {
|
param {
|
||||||
lr_mult: 1.0
|
lr_mult: 1
|
||||||
decay_mult: 1.0
|
decay_mult: 1
|
||||||
}
|
}
|
||||||
param {
|
param {
|
||||||
lr_mult: 2.0
|
lr_mult: 2
|
||||||
decay_mult: 0.0
|
decay_mult: 0
|
||||||
}
|
}
|
||||||
convolution_param {
|
convolution_param {
|
||||||
num_output: 384
|
num_output: 384
|
||||||
@@ -164,12 +164,12 @@ layer {
|
|||||||
bottom: "relu4"
|
bottom: "relu4"
|
||||||
top: "conv5"
|
top: "conv5"
|
||||||
param {
|
param {
|
||||||
lr_mult: 1.0
|
lr_mult: 1
|
||||||
decay_mult: 1.0
|
decay_mult: 1
|
||||||
}
|
}
|
||||||
param {
|
param {
|
||||||
lr_mult: 2.0
|
lr_mult: 2
|
||||||
decay_mult: 0.0
|
decay_mult: 0
|
||||||
}
|
}
|
||||||
convolution_param {
|
convolution_param {
|
||||||
num_output: 256
|
num_output: 256
|
||||||
@@ -201,12 +201,12 @@ layer {
|
|||||||
bottom: "pool5"
|
bottom: "pool5"
|
||||||
top: "fc6"
|
top: "fc6"
|
||||||
param {
|
param {
|
||||||
lr_mult: 1.0
|
lr_mult: 1
|
||||||
decay_mult: 1.0
|
decay_mult: 1
|
||||||
}
|
}
|
||||||
param {
|
param {
|
||||||
lr_mult: 2.0
|
lr_mult: 2
|
||||||
decay_mult: 0.0
|
decay_mult: 0
|
||||||
}
|
}
|
||||||
inner_product_param {
|
inner_product_param {
|
||||||
num_output: 4096
|
num_output: 4096
|
||||||
@@ -233,12 +233,12 @@ layer {
|
|||||||
bottom: "drop6"
|
bottom: "drop6"
|
||||||
top: "fc7"
|
top: "fc7"
|
||||||
param {
|
param {
|
||||||
lr_mult: 1.0
|
lr_mult: 1
|
||||||
decay_mult: 1.0
|
decay_mult: 1
|
||||||
}
|
}
|
||||||
param {
|
param {
|
||||||
lr_mult: 2.0
|
lr_mult: 2
|
||||||
decay_mult: 0.0
|
decay_mult: 0
|
||||||
}
|
}
|
||||||
inner_product_param {
|
inner_product_param {
|
||||||
num_output: 4096
|
num_output: 4096
|
||||||
@@ -265,12 +265,12 @@ layer {
|
|||||||
bottom: "drop7"
|
bottom: "drop7"
|
||||||
top: "fc8"
|
top: "fc8"
|
||||||
param {
|
param {
|
||||||
lr_mult: 1.0
|
lr_mult: 1
|
||||||
decay_mult: 1.0
|
decay_mult: 1
|
||||||
}
|
}
|
||||||
param {
|
param {
|
||||||
lr_mult: 2.0
|
lr_mult: 2
|
||||||
decay_mult: 0.0
|
decay_mult: 0
|
||||||
}
|
}
|
||||||
inner_product_param {
|
inner_product_param {
|
||||||
num_output: 1000
|
num_output: 1000
|
||||||
|
|||||||
158
src/tools/deinplace.cpp
Normal file
158
src/tools/deinplace.cpp
Normal file
@@ -0,0 +1,158 @@
|
|||||||
|
#include <getopt.h>
|
||||||
|
#include <fstream>
|
||||||
|
#include <memory>
|
||||||
|
#include <fcntl.h>
|
||||||
|
#include <google/protobuf/text_format.h>
|
||||||
|
#include <google/protobuf/io/zero_copy_stream_impl.h>
|
||||||
|
#include "caffe.pb.h"
|
||||||
|
|
||||||
|
static struct
|
||||||
|
{
|
||||||
|
bool verbose = false;
|
||||||
|
std::string_view output_filename;
|
||||||
|
std::string_view input_filename;
|
||||||
|
} options;
|
||||||
|
|
||||||
|
template<typename... Args>
|
||||||
|
void verbose_print(Args &&... args)
|
||||||
|
{
|
||||||
|
if (!options.verbose) return;
|
||||||
|
|
||||||
|
(std::cerr << ... << args) << '\n';
|
||||||
|
}
|
||||||
|
stdin
|
||||||
|
void show_help(std::string_view name, int exit_code)
|
||||||
|
{
|
||||||
|
std::cerr << "Usage: " << name << "[OPTIONS] [FILENAME]\n"
|
||||||
|
"\n"
|
||||||
|
"Valid options:\n"
|
||||||
|
"-h\tshow this message\n"
|
||||||
|
"-v\tshow debug messages\n"
|
||||||
|
"-o OUTPUT\t write output to OUTPUT\n"
|
||||||
|
"FILENAME\t read from FILENAME instead of stdin\n" << std::endl;
|
||||||
|
|
||||||
|
exit(exit_code);
|
||||||
|
}
|
||||||
|
|
||||||
|
int get_file_descriptor(std::string_view file_name, int mode)
|
||||||
|
{
|
||||||
|
int fd = open(file_name.data(), mode);
|
||||||
|
if (fd < 0) {
|
||||||
|
perror("Failed to open file");
|
||||||
|
exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
return fd;
|
||||||
|
}
|
||||||
|
|
||||||
|
void read_options(int argc, char **argv)
|
||||||
|
{
|
||||||
|
for (char c; (c = static_cast<char>(getopt(argc, argv, "hvo:"))) != -1;) {
|
||||||
|
switch (c) {
|
||||||
|
case 'v':
|
||||||
|
options.verbose = true;
|
||||||
|
break;
|
||||||
|
case 'h':
|
||||||
|
show_help(argv[0], 0);
|
||||||
|
break;
|
||||||
|
case 'o':
|
||||||
|
options.output_filename = optarg;
|
||||||
|
break;
|
||||||
|
case '?':
|
||||||
|
show_help(argv[0], 1);
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if (argc > optind) {
|
||||||
|
options.input_filename = argv[optind];
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void read_network(caffe::NetParameter &network)
|
||||||
|
{
|
||||||
|
using namespace google::protobuf::io;
|
||||||
|
std::unique_ptr<ZeroCopyInputStream> input;
|
||||||
|
if (!options.input_filename.empty()) {
|
||||||
|
verbose_print("Reading network from", options.input_filename);
|
||||||
|
|
||||||
|
auto fd = get_file_descriptor(options.input_filename, O_RDONLY);
|
||||||
|
FileInputStream *i = new FileInputStream(fd);
|
||||||
|
i->SetCloseOnDelete(true);
|
||||||
|
input.reset(i);
|
||||||
|
} else {
|
||||||
|
verbose_print("Reading network from stdin");
|
||||||
|
input = std::make_unique<IstreamInputStream>(&std::cin);
|
||||||
|
}
|
||||||
|
|
||||||
|
auto result = google::protobuf::TextFormat::Parse(input.get(), &network);
|
||||||
|
if (!result) {
|
||||||
|
std::cerr << "Error reading network file!" << std::endl;
|
||||||
|
exit(2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
template<typename LayerType>
|
||||||
|
void patch_layers(google::protobuf::RepeatedPtrField<LayerType> *layers)
|
||||||
|
{
|
||||||
|
std::map<std::string, std::string> outputs;
|
||||||
|
|
||||||
|
for (LayerType &layer : *layers) {
|
||||||
|
|
||||||
|
for (int i = 0; i < layer.bottom_size(); ++i) {
|
||||||
|
if (auto it = outputs.find(layer.bottom(i)); it != outputs.end()) {
|
||||||
|
verbose_print(layer.name(), " reads from in-place ", layer.bottom(i), ", rewriting.");
|
||||||
|
*layer.mutable_bottom(i) = it->second;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (int i = 0; i < layer.top_size(); ++i) {
|
||||||
|
if (auto it = outputs.find(layer.top(i)); it != outputs.end()) {
|
||||||
|
verbose_print(layer.name(), " works in-place rewriting.");
|
||||||
|
it->second = layer.name();
|
||||||
|
*layer.mutable_top(i) = layer.name();
|
||||||
|
} else {
|
||||||
|
outputs[layer.name()] = layer.name();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
void patch_network(caffe::NetParameter &network)
|
||||||
|
{
|
||||||
|
patch_layers(network.mutable_layer());
|
||||||
|
patch_layers(network.mutable_layers());
|
||||||
|
}
|
||||||
|
|
||||||
|
void write_network(caffe::NetParameter &network)
|
||||||
|
{
|
||||||
|
using namespace google::protobuf::io;
|
||||||
|
std::unique_ptr<ZeroCopyOutputStream> output;
|
||||||
|
if (!options.output_filename.empty()) {
|
||||||
|
verbose_print("Writing network to ", options.output_filename);
|
||||||
|
|
||||||
|
auto fd = get_file_descriptor(options.output_filename, O_WRONLY | O_CREAT | O_TRUNC);
|
||||||
|
FileOutputStream *i = new FileOutputStream(fd);
|
||||||
|
i->SetCloseOnDelete(true);
|
||||||
|
output.reset(i);
|
||||||
|
} else {
|
||||||
|
verbose_print("Writing network to stdout");
|
||||||
|
output = std::make_unique<OstreamOutputStream>(&std::cout);
|
||||||
|
}
|
||||||
|
|
||||||
|
google::protobuf::TextFormat::Print(network, output.get());
|
||||||
|
}
|
||||||
|
|
||||||
|
int main(int argc, char **argv)
|
||||||
|
{
|
||||||
|
GOOGLE_PROTOBUF_VERIFY_VERSION;
|
||||||
|
|
||||||
|
read_options(argc, argv);
|
||||||
|
caffe::NetParameter network;
|
||||||
|
read_network(network);
|
||||||
|
patch_network(network);
|
||||||
|
write_network(network);
|
||||||
|
|
||||||
|
// Deallocate all global protobuf objects.
|
||||||
|
google::protobuf::ShutdownProtobufLibrary();
|
||||||
|
}
|
||||||
@@ -1,7 +0,0 @@
|
|||||||
MKDIR=mkdir -p
|
|
||||||
|
|
||||||
all: generated
|
|
||||||
|
|
||||||
generated: $(wildcard proto/*.proto)
|
|
||||||
$(MKDIR) $@
|
|
||||||
protoc --python_out="$@" $^
|
|
||||||
@@ -1,64 +0,0 @@
|
|||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
import argparse
|
|
||||||
import sys
|
|
||||||
import generated.proto.caffe_pb2 as pb2
|
|
||||||
import google.protobuf.text_format as text_format
|
|
||||||
|
|
||||||
|
|
||||||
def eprint(*args, **kwargs):
|
|
||||||
print(*args, file=sys.stderr, **kwargs)
|
|
||||||
|
|
||||||
|
|
||||||
def get_args():
|
|
||||||
parser = argparse.ArgumentParser(description='De-in-place Caffe networks')
|
|
||||||
parser.add_argument('file', type=argparse.FileType('r'))
|
|
||||||
parser.add_argument('-v', '--verbose', action='store_true',
|
|
||||||
help='Be more verbose.')
|
|
||||||
parser.add_argument('-o', '--output', type=argparse.FileType('w'),
|
|
||||||
default=sys.stdout,
|
|
||||||
help='File to write result to. Default stdout')
|
|
||||||
|
|
||||||
return parser.parse_args()
|
|
||||||
|
|
||||||
|
|
||||||
def load_net(args):
|
|
||||||
with args.file as f:
|
|
||||||
data = "".join(f.readlines())
|
|
||||||
net = pb2.NetParameter()
|
|
||||||
text_format.Merge(data, net)
|
|
||||||
|
|
||||||
return net
|
|
||||||
|
|
||||||
|
|
||||||
def deinplace(args, net):
|
|
||||||
outputs = {}
|
|
||||||
layers = net.layer if net.layer else net.layers
|
|
||||||
for layer in layers:
|
|
||||||
for idx, bottom in enumerate(layer.bottom):
|
|
||||||
if bottom in outputs and bottom != outputs[bottom]:
|
|
||||||
if args.verbose:
|
|
||||||
eprint(layer.name, 'reads from in-place layer, rewriting…')
|
|
||||||
|
|
||||||
layer.bottom[idx] = outputs[bottom]
|
|
||||||
|
|
||||||
for idx, top in enumerate(layer.top):
|
|
||||||
if top in outputs:
|
|
||||||
if args.verbose:
|
|
||||||
eprint(layer.name, 'works in-place, rewriting…')
|
|
||||||
outputs[top] = layer.name
|
|
||||||
layer.top[idx] = layer.name
|
|
||||||
else:
|
|
||||||
outputs[top] = top
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
|
||||||
args = get_args()
|
|
||||||
net = load_net(args)
|
|
||||||
deinplace(args, net)
|
|
||||||
|
|
||||||
args.output.write(text_format.MessageToString(net))
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
|
||||||
main()
|
|
||||||
Reference in New Issue
Block a user