Clean up the loading process.
Now loads one file at a time, reducing the amount of memory needed for large numbers of inputs. Still somewhat heavy for low numbers because of the overhead of loading the model.
This commit is contained in:
@@ -67,18 +67,24 @@ static VisualisationList loadVisualisations(const Options& options)
|
|||||||
{
|
{
|
||||||
using namespace std;
|
using namespace std;
|
||||||
|
|
||||||
auto [layerInfo, layerData] = Simulator::loadSimulationData(options);
|
Simulator simulator(options.model(), options.weights(), options.means());
|
||||||
|
|
||||||
|
const auto layerInfo = simulator.layerInfo();
|
||||||
auto labels = options.labels();
|
auto labels = options.labels();
|
||||||
|
|
||||||
VisualisationList result;
|
VisualisationList result;
|
||||||
|
|
||||||
for (auto &&item : layerData) {
|
auto dumper = options.imageDumper();
|
||||||
|
|
||||||
|
for (auto& input : options.inputs()) {
|
||||||
|
LOG(INFO) << "Simulating " << input;
|
||||||
|
auto item = simulator.simulate(input);
|
||||||
|
|
||||||
vector<unique_ptr<LayerVisualisation>> layers;
|
vector<unique_ptr<LayerVisualisation>> layers;
|
||||||
vector<unique_ptr<Animation>> animations;
|
vector<unique_ptr<Animation>> animations;
|
||||||
LayerData* prevData = nullptr;
|
LayerData* prevData = nullptr;
|
||||||
|
|
||||||
for (LayerData &layer : item) {
|
for (auto &layer : item) {
|
||||||
unique_ptr<LayerVisualisation> layerVisualisation(getVisualisationForLayer(layer, layerInfo.at(layer.name())));
|
unique_ptr<LayerVisualisation> layerVisualisation(getVisualisationForLayer(layer, layerInfo.at(layer.name())));
|
||||||
|
|
||||||
if (prevData != nullptr) {
|
if (prevData != nullptr) {
|
||||||
@@ -88,14 +94,24 @@ static VisualisationList loadVisualisations(const Options& options)
|
|||||||
|
|
||||||
layers.emplace_back(move(layerVisualisation));
|
layers.emplace_back(move(layerVisualisation));
|
||||||
prevData = &layer;
|
prevData = &layer;
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
VisualisationList::value_type dataSet;
|
VisualisationList::value_type dataSet;
|
||||||
|
|
||||||
if (labels) {
|
if (labels) {
|
||||||
|
auto &last = *item.rbegin();
|
||||||
|
auto bestIndex = std::distance(last.data(), max_element(last.data(), last.data() + last.numEntries()));
|
||||||
|
LOG(INFO) << "Got answer: " << labels->at(bestIndex) << endl;
|
||||||
animations.emplace_back(new LabelVisualisation(layers.rbegin()->get()->nodePositions(), *prevData, labels.value()));
|
animations.emplace_back(new LabelVisualisation(layers.rbegin()->get()->nodePositions(), *prevData, labels.value()));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if (dumper) {
|
||||||
|
for (auto &layer : item) {
|
||||||
|
dumper->dump(layer);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
for (auto i = 0u; i < layers.size(); ++i) {
|
for (auto i = 0u; i < layers.size(); ++i) {
|
||||||
auto interaction = i < animations.size() ? move(animations[i]) : nullptr;
|
auto interaction = i < animations.size() ? move(animations[i]) : nullptr;
|
||||||
dataSet.emplace_back(move(layers[i]), move(interaction));
|
dataSet.emplace_back(move(layers[i]), move(interaction));
|
||||||
|
|||||||
@@ -231,34 +231,3 @@ const map<string, LayerInfo> & Simulator::layerInfo() const
|
|||||||
{
|
{
|
||||||
return pImpl->layerInfo();
|
return pImpl->layerInfo();
|
||||||
}
|
}
|
||||||
|
|
||||||
std::pair<std::map<std::string, LayerInfo>, std::vector<std::vector<LayerData>>>
|
|
||||||
fmri::Simulator::loadSimulationData(const Options &options)
|
|
||||||
{
|
|
||||||
Simulator simulator(options.model(), options.weights(), options.means());
|
|
||||||
|
|
||||||
std::vector<std::vector<LayerData>> results;
|
|
||||||
transform(options.inputs().begin(), options.inputs().end(), back_inserter(results), [&simulator] (auto& x) {
|
|
||||||
return simulator.simulate(x);
|
|
||||||
});
|
|
||||||
|
|
||||||
auto dumper = options.imageDumper();
|
|
||||||
if (dumper) {
|
|
||||||
for (auto &layer : *results.begin()) {
|
|
||||||
dumper->dump(layer);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const auto optLabels = options.labels();
|
|
||||||
|
|
||||||
if (optLabels) {
|
|
||||||
auto& labels = *optLabels;
|
|
||||||
for (const auto& result : results) {
|
|
||||||
auto &last = *result.rbegin();
|
|
||||||
auto bestIndex = std::distance(last.data(), max_element(last.data(), last.data() + last.numEntries()));
|
|
||||||
LOG(INFO) << "Got answer: " << labels[bestIndex] << endl;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return make_pair(simulator.layerInfo(), std::move(results));
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -19,8 +19,6 @@ namespace fmri {
|
|||||||
vector<LayerData> simulate(const string &input_file);
|
vector<LayerData> simulate(const string &input_file);
|
||||||
const std::map<std::string, LayerInfo>& layerInfo() const;
|
const std::map<std::string, LayerInfo>& layerInfo() const;
|
||||||
|
|
||||||
static std::pair<std::map<std::string, LayerInfo>, std::vector<std::vector<LayerData>>> loadSimulationData(const Options &options);
|
|
||||||
|
|
||||||
private:
|
private:
|
||||||
struct Impl;
|
struct Impl;
|
||||||
std::unique_ptr<Impl> pImpl;
|
std::unique_ptr<Impl> pImpl;
|
||||||
|
|||||||
@@ -93,6 +93,7 @@ static fmri::LayerVisualisation *getAppropriateLayer(const fmri::LayerData &data
|
|||||||
|
|
||||||
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &data, const fmri::LayerInfo &info)
|
fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &data, const fmri::LayerInfo &info)
|
||||||
{
|
{
|
||||||
|
LOG(INFO) << "Loading state visualisation for " << data.name();
|
||||||
auto layer = getAppropriateLayer(data, info);
|
auto layer = getAppropriateLayer(data, info);
|
||||||
layer->setupLayerName(data.name(), info.type());
|
layer->setupLayerName(data.name(), info.type());
|
||||||
|
|
||||||
@@ -102,8 +103,6 @@ fmri::LayerVisualisation *fmri::getVisualisationForLayer(const fmri::LayerData &
|
|||||||
static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, const fmri::LayerInfo &layer,
|
static Animation *getFullyConnectedAnimation(const fmri::LayerData &prevState, const fmri::LayerInfo &layer,
|
||||||
const vector<float> &prevPositions, const vector<float> &curPositions)
|
const vector<float> &prevPositions, const vector<float> &curPositions)
|
||||||
{
|
{
|
||||||
LOG(INFO) << "Computing top interactions for " << layer.name() << endl;
|
|
||||||
|
|
||||||
auto data = prevState.data();
|
auto data = prevState.data();
|
||||||
|
|
||||||
CHECK_GE(layer.parameters().size(), 1) << "Layer should have correct parameters";
|
CHECK_GE(layer.parameters().size(), 1) << "Layer should have correct parameters";
|
||||||
@@ -250,6 +249,7 @@ Animation * fmri::getActivityAnimation(const fmri::LayerData &prevState, const f
|
|||||||
return nullptr;
|
return nullptr;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
LOG(INFO) << "Loading interaction for " << layer.name();
|
||||||
|
|
||||||
switch (layer.type()) {
|
switch (layer.type()) {
|
||||||
case LayerInfo::Type::InnerProduct:
|
case LayerInfo::Type::InnerProduct:
|
||||||
|
|||||||
Reference in New Issue
Block a user