Create a python script for rewriting in-place layers.
Using in-place layers makes it impossible to see the state of an individual layer, so this script rewrites them into each having their own output layer.
This commit is contained in:
1
.gitignore
vendored
1
.gitignore
vendored
@@ -1 +1,2 @@
|
|||||||
build
|
build
|
||||||
|
tools/generated
|
||||||
|
|||||||
7
tools/Makefile
Normal file
7
tools/Makefile
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
MKDIR=mkdir -p
|
||||||
|
|
||||||
|
all: generated
|
||||||
|
|
||||||
|
generated: $(wildcard proto/*.proto)
|
||||||
|
$(MKDIR) $@
|
||||||
|
protoc --python_out="$@" $^
|
||||||
63
tools/deinplace.py
Executable file
63
tools/deinplace.py
Executable file
@@ -0,0 +1,63 @@
|
|||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
import argparse
|
||||||
|
import sys
|
||||||
|
import generated.proto.caffe_pb2 as pb2
|
||||||
|
import google.protobuf.text_format as text_format
|
||||||
|
|
||||||
|
|
||||||
|
def eprint(*args, **kwargs):
|
||||||
|
print(*args, file=sys.stderr, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
|
def get_args():
|
||||||
|
parser = argparse.ArgumentParser(description='De-in-place Caffe networks')
|
||||||
|
parser.add_argument('file', type=argparse.FileType('r'))
|
||||||
|
parser.add_argument('-v', '--verbose', action='store_true',
|
||||||
|
help='Be more verbose.')
|
||||||
|
parser.add_argument('-o', '--output', type=argparse.FileType('w'),
|
||||||
|
default=sys.stdout,
|
||||||
|
help='File to write result to. Default stdout')
|
||||||
|
|
||||||
|
return parser.parse_args()
|
||||||
|
|
||||||
|
|
||||||
|
def load_net(args):
|
||||||
|
with args.file as f:
|
||||||
|
data = "".join(f.readlines())
|
||||||
|
net = pb2.NetParameter()
|
||||||
|
text_format.Merge(data, net)
|
||||||
|
|
||||||
|
return net
|
||||||
|
|
||||||
|
|
||||||
|
def deinplace(args, net):
|
||||||
|
outputs = {}
|
||||||
|
for layer in net.layer:
|
||||||
|
for idx, bottom in enumerate(layer.bottom):
|
||||||
|
if bottom in outputs and bottom != outputs[bottom]:
|
||||||
|
if args.verbose:
|
||||||
|
eprint(layer.name, 'reads from in-place layer, rewriting…')
|
||||||
|
|
||||||
|
layer.bottom[idx] = outputs[bottom]
|
||||||
|
|
||||||
|
for idx, top in enumerate(layer.top):
|
||||||
|
if top in outputs:
|
||||||
|
if args.verbose:
|
||||||
|
eprint(layer.name, 'works in-place, rewriting…')
|
||||||
|
outputs[top] = layer.name
|
||||||
|
layer.top[idx] = layer.name
|
||||||
|
else:
|
||||||
|
outputs[top] = top
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
args = get_args()
|
||||||
|
net = load_net(args)
|
||||||
|
deinplace(args, net)
|
||||||
|
|
||||||
|
args.output.write(text_format.MessageToString(net))
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
0
tools/model-dedup.prototxt
Normal file
0
tools/model-dedup.prototxt
Normal file
1412
tools/proto/caffe.proto
Normal file
1412
tools/proto/caffe.proto
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user