diff --git a/CMakeLists.txt b/CMakeLists.txt index 2cbdcf7..04d471a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -37,3 +37,15 @@ target_include_directories(fmri PUBLIC # Allow the package to be installed install(TARGETS fmri DESTINATION bin) + +# Build instructions for the deinplace tool +find_package(Protobuf REQUIRED) +protobuf_generate_cpp(CAFFE_PROTO_CPP CAFFE_PROTO_HEADERS "src/proto/caffe.proto") +add_executable(deinplace + "src/tools/deinplace.cpp" + ${CAFFE_PROTO_CPP} + ${CAFFE_PROTO_HEADERS} + ) +target_compile_options(deinplace PRIVATE "-Wall" "-Wextra" "-pedantic") +target_link_libraries(deinplace protobuf::libprotobuf) +target_include_directories(deinplace PRIVATE "${CMAKE_CURRENT_BINARY_DIR}") diff --git a/data/models/alexnet/model-dedup.prototxt b/data/models/alexnet/model-dedup.prototxt index 6b71082..02fca5d 100644 --- a/data/models/alexnet/model-dedup.prototxt +++ b/data/models/alexnet/model-dedup.prototxt @@ -18,12 +18,12 @@ layer { bottom: "data" top: "conv1" param { - lr_mult: 1.0 - decay_mult: 1.0 + lr_mult: 1 + decay_mult: 1 } param { - lr_mult: 2.0 - decay_mult: 0.0 + lr_mult: 2 + decay_mult: 0 } convolution_param { num_output: 96 @@ -65,12 +65,12 @@ layer { bottom: "pool1" top: "conv2" param { - lr_mult: 1.0 - decay_mult: 1.0 + lr_mult: 1 + decay_mult: 1 } param { - lr_mult: 2.0 - decay_mult: 0.0 + lr_mult: 2 + decay_mult: 0 } convolution_param { num_output: 256 @@ -113,12 +113,12 @@ layer { bottom: "pool2" top: "conv3" param { - lr_mult: 1.0 - decay_mult: 1.0 + lr_mult: 1 + decay_mult: 1 } param { - lr_mult: 2.0 - decay_mult: 0.0 + lr_mult: 2 + decay_mult: 0 } convolution_param { num_output: 384 @@ -138,12 +138,12 @@ layer { bottom: "relu3" top: "conv4" param { - lr_mult: 1.0 - decay_mult: 1.0 + lr_mult: 1 + decay_mult: 1 } param { - lr_mult: 2.0 - decay_mult: 0.0 + lr_mult: 2 + decay_mult: 0 } convolution_param { num_output: 384 @@ -164,12 +164,12 @@ layer { bottom: "relu4" top: "conv5" param { - lr_mult: 1.0 - decay_mult: 1.0 + lr_mult: 1 + decay_mult: 1 } param { - lr_mult: 2.0 - decay_mult: 0.0 + lr_mult: 2 + decay_mult: 0 } convolution_param { num_output: 256 @@ -201,12 +201,12 @@ layer { bottom: "pool5" top: "fc6" param { - lr_mult: 1.0 - decay_mult: 1.0 + lr_mult: 1 + decay_mult: 1 } param { - lr_mult: 2.0 - decay_mult: 0.0 + lr_mult: 2 + decay_mult: 0 } inner_product_param { num_output: 4096 @@ -233,12 +233,12 @@ layer { bottom: "drop6" top: "fc7" param { - lr_mult: 1.0 - decay_mult: 1.0 + lr_mult: 1 + decay_mult: 1 } param { - lr_mult: 2.0 - decay_mult: 0.0 + lr_mult: 2 + decay_mult: 0 } inner_product_param { num_output: 4096 @@ -265,12 +265,12 @@ layer { bottom: "drop7" top: "fc8" param { - lr_mult: 1.0 - decay_mult: 1.0 + lr_mult: 1 + decay_mult: 1 } param { - lr_mult: 2.0 - decay_mult: 0.0 + lr_mult: 2 + decay_mult: 0 } inner_product_param { num_output: 1000 diff --git a/tools/proto/caffe.proto b/src/proto/caffe.proto similarity index 100% rename from tools/proto/caffe.proto rename to src/proto/caffe.proto diff --git a/src/tools/deinplace.cpp b/src/tools/deinplace.cpp new file mode 100644 index 0000000..b346095 --- /dev/null +++ b/src/tools/deinplace.cpp @@ -0,0 +1,158 @@ +#include +#include +#include +#include +#include +#include +#include "caffe.pb.h" + +static struct +{ + bool verbose = false; + std::string_view output_filename; + std::string_view input_filename; +} options; + +template +void verbose_print(Args &&... args) +{ + if (!options.verbose) return; + + (std::cerr << ... << args) << '\n'; +} +stdin +void show_help(std::string_view name, int exit_code) +{ + std::cerr << "Usage: " << name << "[OPTIONS] [FILENAME]\n" + "\n" + "Valid options:\n" + "-h\tshow this message\n" + "-v\tshow debug messages\n" + "-o OUTPUT\t write output to OUTPUT\n" + "FILENAME\t read from FILENAME instead of stdin\n" << std::endl; + + exit(exit_code); +} + +int get_file_descriptor(std::string_view file_name, int mode) +{ + int fd = open(file_name.data(), mode); + if (fd < 0) { + perror("Failed to open file"); + exit(1); + } + + return fd; +} + +void read_options(int argc, char **argv) +{ + for (char c; (c = static_cast(getopt(argc, argv, "hvo:"))) != -1;) { + switch (c) { + case 'v': + options.verbose = true; + break; + case 'h': + show_help(argv[0], 0); + break; + case 'o': + options.output_filename = optarg; + break; + case '?': + show_help(argv[0], 1); + break; + } + } + + if (argc > optind) { + options.input_filename = argv[optind]; + } +} + +void read_network(caffe::NetParameter &network) +{ + using namespace google::protobuf::io; + std::unique_ptr input; + if (!options.input_filename.empty()) { + verbose_print("Reading network from", options.input_filename); + + auto fd = get_file_descriptor(options.input_filename, O_RDONLY); + FileInputStream *i = new FileInputStream(fd); + i->SetCloseOnDelete(true); + input.reset(i); + } else { + verbose_print("Reading network from stdin"); + input = std::make_unique(&std::cin); + } + + auto result = google::protobuf::TextFormat::Parse(input.get(), &network); + if (!result) { + std::cerr << "Error reading network file!" << std::endl; + exit(2); + } +} + +template +void patch_layers(google::protobuf::RepeatedPtrField *layers) +{ + std::map outputs; + + for (LayerType &layer : *layers) { + + for (int i = 0; i < layer.bottom_size(); ++i) { + if (auto it = outputs.find(layer.bottom(i)); it != outputs.end()) { + verbose_print(layer.name(), " reads from in-place ", layer.bottom(i), ", rewriting."); + *layer.mutable_bottom(i) = it->second; + } + } + + for (int i = 0; i < layer.top_size(); ++i) { + if (auto it = outputs.find(layer.top(i)); it != outputs.end()) { + verbose_print(layer.name(), " works in-place rewriting."); + it->second = layer.name(); + *layer.mutable_top(i) = layer.name(); + } else { + outputs[layer.name()] = layer.name(); + } + } + } +} + +void patch_network(caffe::NetParameter &network) +{ + patch_layers(network.mutable_layer()); + patch_layers(network.mutable_layers()); +} + +void write_network(caffe::NetParameter &network) +{ + using namespace google::protobuf::io; + std::unique_ptr output; + if (!options.output_filename.empty()) { + verbose_print("Writing network to ", options.output_filename); + + auto fd = get_file_descriptor(options.output_filename, O_WRONLY | O_CREAT | O_TRUNC); + FileOutputStream *i = new FileOutputStream(fd); + i->SetCloseOnDelete(true); + output.reset(i); + } else { + verbose_print("Writing network to stdout"); + output = std::make_unique(&std::cout); + } + + google::protobuf::TextFormat::Print(network, output.get()); +} + +int main(int argc, char **argv) +{ + GOOGLE_PROTOBUF_VERIFY_VERSION; + + read_options(argc, argv); + caffe::NetParameter network; + read_network(network); + patch_network(network); + write_network(network); + + // Deallocate all global protobuf objects. + google::protobuf::ShutdownProtobufLibrary(); +} diff --git a/tools/Makefile b/tools/Makefile deleted file mode 100644 index 2517e65..0000000 --- a/tools/Makefile +++ /dev/null @@ -1,7 +0,0 @@ -MKDIR=mkdir -p - -all: generated - -generated: $(wildcard proto/*.proto) - $(MKDIR) $@ - protoc --python_out="$@" $^ diff --git a/tools/deinplace.py b/tools/deinplace.py deleted file mode 100755 index 52f5972..0000000 --- a/tools/deinplace.py +++ /dev/null @@ -1,64 +0,0 @@ -#!/usr/bin/env python3 - -import argparse -import sys -import generated.proto.caffe_pb2 as pb2 -import google.protobuf.text_format as text_format - - -def eprint(*args, **kwargs): - print(*args, file=sys.stderr, **kwargs) - - -def get_args(): - parser = argparse.ArgumentParser(description='De-in-place Caffe networks') - parser.add_argument('file', type=argparse.FileType('r')) - parser.add_argument('-v', '--verbose', action='store_true', - help='Be more verbose.') - parser.add_argument('-o', '--output', type=argparse.FileType('w'), - default=sys.stdout, - help='File to write result to. Default stdout') - - return parser.parse_args() - - -def load_net(args): - with args.file as f: - data = "".join(f.readlines()) - net = pb2.NetParameter() - text_format.Merge(data, net) - - return net - - -def deinplace(args, net): - outputs = {} - layers = net.layer if net.layer else net.layers - for layer in layers: - for idx, bottom in enumerate(layer.bottom): - if bottom in outputs and bottom != outputs[bottom]: - if args.verbose: - eprint(layer.name, 'reads from in-place layer, rewriting…') - - layer.bottom[idx] = outputs[bottom] - - for idx, top in enumerate(layer.top): - if top in outputs: - if args.verbose: - eprint(layer.name, 'works in-place, rewriting…') - outputs[top] = layer.name - layer.top[idx] = layer.name - else: - outputs[top] = top - - -def main(): - args = get_args() - net = load_net(args) - deinplace(args, net) - - args.output.write(text_format.MessageToString(net)) - - -if __name__ == '__main__': - main()