sustaining_gazes/lib/3rdParty/dlib/include/dlib/svm/sequence_segmenter.h
2016-04-28 15:40:36 -04:00

468 lines
17 KiB
C++

// Copyright (C) 2013 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_SEQUENCE_SeGMENTER_H_h_
#define DLIB_SEQUENCE_SeGMENTER_H_h_
#include "sequence_segmenter_abstract.h"
#include "../matrix.h"
#include "sequence_labeler.h"
#include <vector>
namespace dlib
{
// This namespace contains implementation details for the sequence_segmenter.
namespace impl_ss
{
// ------------------------------------------------------------------------------------
// Integer codes for the per-element tags used by the segmenter.  The BIO
// model uses only the first three codes (BEGIN, INSIDE, OUTSIDE); the BILOU
// model uses all five.
const unsigned int BEGIN   = 0u;  // first element of a multi-element segment
const unsigned int INSIDE  = 1u;  // element continuing a segment opened by BEGIN
const unsigned int OUTSIDE = 2u;  // element that is not part of any segment
const unsigned int LAST    = 3u;  // (BILOU only) closing element of a segment
const unsigned int UNIT    = 4u;  // (BILOU only) segment made of a single element
// ------------------------------------------------------------------------------------
template <typename ss_feature_extractor>
class feature_extractor
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            Adapter that wraps a sequence_segmenter style feature extractor
            (ss_feature_extractor) and exposes the member functions used by a
            sequence_labeler: num_features(), order(), num_labels(),
            reject_labeling() and get_features().

            The feature vector produced by get_features() is laid out as:
                - one block per window position, each holding num_labels()
                  copies of the base features (plus num_labels()^2 more copies
                  when use_high_order_features is set),
                - num_labels()^2 label transition indicators,
                - num_labels() per-label bias indicators.
    !*/

public:
    typedef typename ss_feature_extractor::sequence_type sequence_type;

    // The wrapped, user supplied, feature extractor.
    ss_feature_extractor fe;

    feature_extractor() {}
    feature_extractor(const ss_feature_extractor& ss_fe_) : fe(ss_fe_) {}

    // 3 labels for the BIO model, 5 for the BILOU model.
    unsigned long num_labels() const
    {
        return ss_feature_extractor::use_BIO_model ? 3 : 5;
    }

    // Only the current and the previous label are read by get_features()
    // and reject_labeling().
    unsigned long order() const
    {
        return 1;
    }

    // Total dimensionality of the feature vector (see the layout above).
    unsigned long num_features() const
    {
        const unsigned long NL = num_labels();
        const unsigned long copies_per_slot =
            ss_feature_extractor::use_high_order_features ? (NL*NL+NL) : NL;
        return NL + NL*NL + copies_per_slot*fe.num_features()*fe.window_size();
    }

    // Number of leading weights that are constrained to be non-negative.
    unsigned long num_nonnegative_weights (
    ) const
    {
        if (ss_feature_extractor::allow_negative_weights)
            return 0;

        // Everything is non-negative except for the trailing label
        // transition and bias features.
        const unsigned long NL = num_labels();
        return num_features() - NL*NL - NL;
    }

    // Only the wrapped extractor carries state, so it is all that is saved.
    friend void serialize(const feature_extractor& item, std::ostream& out)
    {
        serialize(item.fe, out);
    }

    friend void deserialize(feature_extractor& item, std::istream& in)
    {
        deserialize(item.fe, in);
    }

private:

    template <typename feature_setter>
    struct dot_functor
    {
        /*!
            WHAT THIS OBJECT REPRESENTS
                Forwards feature assignments to the feature_setter supplied
                by a sequence_labeler after shifting every feature index by
                a fixed offset.
        !*/

        dot_functor(feature_setter& set_feature_, unsigned long offset_) :
            set_feature(set_feature_), offset(offset_) {}

        feature_setter& set_feature;
        unsigned long offset;

        // Indicator style feature.
        inline void operator() (
            unsigned long feat_index
        )
        {
            set_feature(offset+feat_index);
        }

        // Real valued feature.
        inline void operator() (
            unsigned long feat_index,
            double feat_value
        )
        {
            set_feature(offset+feat_index, feat_value);
        }
    };

public:

    // Returns true when the label pair (y(1) = previous, y(0) = current) at
    // position pos can't be part of a well formed segmentation.  When y has a
    // single element there is no previous label.
    template <typename EXP>
    bool reject_labeling (
        const sequence_type& x,
        const matrix_exp<EXP>& y,
        unsigned long pos
    ) const
    {
        if (ss_feature_extractor::use_BIO_model)
        {
            // In the BIO model the only illegal patterns are an INSIDE that
            // follows an OUTSIDE and an INSIDE with nothing before it.
            if (y.size() > 1)
                return y(0) == INSIDE && y(1) == OUTSIDE;
            return y.size() == 1 && y(0) == INSIDE;
        }

        // BILOU model from here on.
        const bool at_end = (pos == x.size()-1);

        if (y.size() > 1)
        {
            // The previous label either leaves a segment open (BEGIN, INSIDE)
            // or leaves no segment open (OUTSIDE, LAST, UNIT).
            const bool prev_open   = (y(1) == BEGIN || y(1) == INSIDE);
            const bool prev_closed = (y(1) == OUTSIDE || y(1) == LAST || y(1) == UNIT);

            // An open segment may only be continued or closed.
            if (prev_open && (y(0) == OUTSIDE || y(0) == UNIT || y(0) == BEGIN))
                return true;

            // With no open segment there is nothing to continue or close.
            if (prev_closed && (y(0) == INSIDE || y(0) == LAST))
                return true;

            // A segment can't be left open at the end of the sequence.
            return at_end && (y(0) == BEGIN || y(0) == INSIDE);
        }

        // No previous label: nothing to continue or close, and a segment
        // can't be opened on the final element.
        return y(0) == INSIDE || y(0) == LAST || (at_end && y(0) == BEGIN);
    }

    // Emits, through set_feature, all features for labeling element position
    // of x with y(0), given y(1) as the previous label when y.size() > 1.
    template <typename feature_setter, typename EXP>
    void get_features (
        feature_setter& set_feature,
        const sequence_type& x,
        const matrix_exp<EXP>& y,
        unsigned long position
    ) const
    {
        const int window_size = fe.window_size();
        const int base_dims = fe.num_features();
        const bool have_prev_label = y.size() > 1;

        // Number of feature slots reserved for each window position.
        unsigned long slot_width = num_labels()*base_dims;
        if (ss_feature_extractor::use_high_order_features)
            slot_width += num_labels()*num_labels()*base_dims;

        unsigned long offset = 0;
        for (int i = 0; i < window_size; ++i, offset += slot_width)
        {
            // Window positions hanging off either end of x contribute no
            // features, but their slot is still skipped over.
            const long pos = i-window_size/2 + static_cast<long>(position);
            if (pos < 0 || pos >= static_cast<long>(x.size()))
                continue;

            // Base features selected by the current label.
            const unsigned long label_off = y(0)*base_dims;
            dot_functor<feature_setter> by_label(set_feature, offset+label_off);
            fe.get_features(by_label, x, pos);

            if (ss_feature_extractor::use_high_order_features && have_prev_label)
            {
                // Base features selected by the (current, previous) label pair.
                const unsigned long pair_off = num_labels()*base_dims + (y(0)*num_labels()+y(1))*base_dims;
                dot_functor<feature_setter> by_pair(set_feature, offset+pair_off);
                fe.get_features(by_pair, x, pos);
            }
        }

        // Indicator for the transition from the previous label to the
        // current one.
        if (have_prev_label)
            set_feature(offset + y(1)*num_labels() + y(0));
        offset += num_labels()*num_labels();

        // Indicator for the current label, i.e. the per label bias.
        set_feature(offset + y(0));
    }
};
} // end namespace impl_ss
// ----------------------------------------------------------------------------------------
// Returns the length of the weight vector a sequence_segmenter needs when it
// uses the feature extractor fe: one bias per label, one indicator per label
// pair, plus the windowed base features replicated per label (and also per
// label pair when use_high_order_features is set).
template <
    typename feature_extractor
    >
unsigned long total_feature_vector_size (
    const feature_extractor& fe
)
{
    // 3 labels for the BIO model, 5 for BILOU.
    const unsigned long num_labels = feature_extractor::use_BIO_model ? 3 : 5;

    // Bias plus label transition indicators.
    const unsigned long label_terms = num_labels + num_labels*num_labels;

    // How many copies of each windowed base feature are kept.
    const unsigned long copies =
        feature_extractor::use_high_order_features ? label_terms : num_labels;

    return label_terms + copies*fe.num_features()*fe.window_size();
}
// ----------------------------------------------------------------------------------------
template <
typename feature_extractor
>
class sequence_segmenter
{
public:
typedef typename feature_extractor::sequence_type sample_sequence_type;
typedef std::vector<std::pair<unsigned long, unsigned long> > segmented_sequence_type;
sequence_segmenter()
{
#ifdef ENABLE_ASSERTS
const feature_extractor& fe = labeler.get_feature_extractor().fe;
DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1,
"\t sequence_segmenter::sequence_segmenter()"
<< "\n\t An invalid feature extractor was supplied."
<< "\n\t fe.window_size(): " << fe.window_size()
<< "\n\t fe.num_features(): " << fe.num_features()
<< "\n\t this: " << this
);
#endif
}
explicit sequence_segmenter(
const matrix<double,0,1>& weights
) :
labeler(weights)
{
#ifdef ENABLE_ASSERTS
const feature_extractor& fe = labeler.get_feature_extractor().fe;
// make sure requires clause is not broken
DLIB_ASSERT(total_feature_vector_size(fe) == (unsigned long)weights.size(),
"\t sequence_segmenter::sequence_segmenter(weights)"
<< "\n\t These sizes should match"
<< "\n\t total_feature_vector_size(fe): " << total_feature_vector_size(fe)
<< "\n\t weights.size(): " << weights.size()
<< "\n\t this: " << this
);
DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1,
"\t sequence_segmenter::sequence_segmenter()"
<< "\n\t An invalid feature extractor was supplied."
<< "\n\t fe.window_size(): " << fe.window_size()
<< "\n\t fe.num_features(): " << fe.num_features()
<< "\n\t this: " << this
);
#endif
}
sequence_segmenter(
const matrix<double,0,1>& weights,
const feature_extractor& fe
) :
labeler(weights, impl_ss::feature_extractor<feature_extractor>(fe))
{
// make sure requires clause is not broken
DLIB_ASSERT(total_feature_vector_size(fe) == (unsigned long)weights.size(),
"\t sequence_segmenter::sequence_segmenter(weights,fe)"
<< "\n\t These sizes should match"
<< "\n\t total_feature_vector_size(fe): " << total_feature_vector_size(fe)
<< "\n\t weights.size(): " << weights.size()
<< "\n\t this: " << this
);
DLIB_ASSERT(fe.window_size() >= 1 && fe.num_features() >= 1,
"\t sequence_segmenter::sequence_segmenter()"
<< "\n\t An invalid feature extractor was supplied."
<< "\n\t fe.window_size(): " << fe.window_size()
<< "\n\t fe.num_features(): " << fe.num_features()
<< "\n\t this: " << this
);
}
const feature_extractor& get_feature_extractor (
) const { return labeler.get_feature_extractor().fe; }
const matrix<double,0,1>& get_weights (
) const { return labeler.get_weights(); }
segmented_sequence_type operator() (
const sample_sequence_type& x
) const
{
segmented_sequence_type y;
segment_sequence(x,y);
return y;
}
void segment_sequence (
const sample_sequence_type& x,
segmented_sequence_type& y
) const
{
y.clear();
std::vector<unsigned long> labels;
labeler.label_sequence(x, labels);
if (feature_extractor::use_BIO_model)
{
// Convert from BIO tagging to the explicit segments representation.
for (unsigned long i = 0; i < labels.size(); ++i)
{
if (labels[i] == impl_ss::BEGIN)
{
const unsigned long begin = i;
++i;
while (i < labels.size() && labels[i] == impl_ss::INSIDE)
++i;
y.push_back(std::make_pair(begin, i));
--i;
}
}
}
else
{
// Convert from BILOU tagging to the explicit segments representation.
for (unsigned long i = 0; i < labels.size(); ++i)
{
if (labels[i] == impl_ss::BEGIN)
{
const unsigned long begin = i;
++i;
while (i < labels.size() && labels[i] == impl_ss::INSIDE)
++i;
y.push_back(std::make_pair(begin, i+1));
}
else if (labels[i] == impl_ss::UNIT)
{
y.push_back(std::make_pair(i, i+1));
}
}
}
}
friend void serialize(const sequence_segmenter& item, std::ostream& out)
{
int version = 1;
serialize(version, out);
// Save these just so we can compare them when we deserialize and make
// sure the feature_extractor being used is compatible with the model being
// loaded.
serialize(feature_extractor::use_BIO_model, out);
serialize(feature_extractor::use_high_order_features, out);
serialize(total_feature_vector_size(item.get_feature_extractor()), out);
serialize(item.labeler, out);
}
friend void deserialize(sequence_segmenter& item, std::istream& in)
{
int version = 0;
deserialize(version, in);
if (version != 1)
throw serialization_error("Unexpected version found while deserializing dlib::sequence_segmenter.");
// Try to check if the saved model is compatible with the current feature
// extractor.
bool use_BIO_model, use_high_order_features;
unsigned long dims;
deserialize(use_BIO_model, in);
deserialize(use_high_order_features, in);
deserialize(dims, in);
deserialize(item.labeler, in);
if (use_BIO_model != feature_extractor::use_BIO_model)
{
throw serialization_error("Incompatible feature extractor found while deserializing "
"dlib::sequence_segmenter. Wrong value of use_BIO_model.");
}
if (use_high_order_features != feature_extractor::use_high_order_features)
{
throw serialization_error("Incompatible feature extractor found while deserializing "
"dlib::sequence_segmenter. Wrong value of use_high_order_features.");
}
if (dims != total_feature_vector_size(item.get_feature_extractor()))
{
throw serialization_error("Incompatible feature extractor found while deserializing "
"dlib::sequence_segmenter. Wrong value of total_feature_vector_size().");
}
}
private:
sequence_labeler<impl_ss::feature_extractor<feature_extractor> > labeler;
};
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_SEQUENCE_SeGMENTER_H_h_