Create a python script for rewriting in-place layers.
Using in-place layers makes it impossible to see the state of an individual layer, so this script rewrites them into each having their own output layer.
This commit is contained in:
7
tools/Makefile
Normal file
7
tools/Makefile
Normal file
@@ -0,0 +1,7 @@
|
||||
MKDIR=mkdir -p
|
||||
|
||||
all: generated
|
||||
|
||||
generated: $(wildcard proto/*.proto)
|
||||
$(MKDIR) $@
|
||||
protoc --python_out="$@" $^
|
||||
63
tools/deinplace.py
Executable file
63
tools/deinplace.py
Executable file
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
import argparse
|
||||
import sys
|
||||
import generated.proto.caffe_pb2 as pb2
|
||||
import google.protobuf.text_format as text_format
|
||||
|
||||
|
||||
def eprint(*args, **kwargs):
|
||||
print(*args, file=sys.stderr, **kwargs)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='De-in-place Caffe networks')
|
||||
parser.add_argument('file', type=argparse.FileType('r'))
|
||||
parser.add_argument('-v', '--verbose', action='store_true',
|
||||
help='Be more verbose.')
|
||||
parser.add_argument('-o', '--output', type=argparse.FileType('w'),
|
||||
default=sys.stdout,
|
||||
help='File to write result to. Default stdout')
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def load_net(args):
|
||||
with args.file as f:
|
||||
data = "".join(f.readlines())
|
||||
net = pb2.NetParameter()
|
||||
text_format.Merge(data, net)
|
||||
|
||||
return net
|
||||
|
||||
|
||||
def deinplace(args, net):
|
||||
outputs = {}
|
||||
for layer in net.layer:
|
||||
for idx, bottom in enumerate(layer.bottom):
|
||||
if bottom in outputs and bottom != outputs[bottom]:
|
||||
if args.verbose:
|
||||
eprint(layer.name, 'reads from in-place layer, rewriting…')
|
||||
|
||||
layer.bottom[idx] = outputs[bottom]
|
||||
|
||||
for idx, top in enumerate(layer.top):
|
||||
if top in outputs:
|
||||
if args.verbose:
|
||||
eprint(layer.name, 'works in-place, rewriting…')
|
||||
outputs[top] = layer.name
|
||||
layer.top[idx] = layer.name
|
||||
else:
|
||||
outputs[top] = top
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
net = load_net(args)
|
||||
deinplace(args, net)
|
||||
|
||||
args.output.write(text_format.MessageToString(net))
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
0
tools/model-dedup.prototxt
Normal file
0
tools/model-dedup.prototxt
Normal file
1412
tools/proto/caffe.proto
Normal file
1412
tools/proto/caffe.proto
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user