294 lines
9.2 KiB
C++
294 lines
9.2 KiB
C++
// Copyright (C) 2013 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
|
|
|
|
#include <sstream>
|
|
#include "tester.h"
|
|
#include <dlib/svm_threaded.h>
|
|
#include <dlib/rand.h>
|
|
|
|
|
|
namespace
|
|
{
|
|
using namespace test;
|
|
using namespace dlib;
|
|
using namespace std;
|
|
|
|
logger dlog("test.sequence_segmenter");
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
dlib::rand rnd;
|
|
|
|
template <bool use_BIO_model_, bool use_high_order_features_, bool allow_negative_weights_>
|
|
class unigram_extractor
|
|
{
|
|
public:
|
|
|
|
const static bool use_BIO_model = use_BIO_model_;
|
|
const static bool use_high_order_features = use_high_order_features_;
|
|
const static bool allow_negative_weights = allow_negative_weights_;
|
|
|
|
typedef std::vector<unsigned long> sequence_type;
|
|
|
|
std::map<unsigned long, matrix<double,0,1> > feats;
|
|
|
|
unigram_extractor()
|
|
{
|
|
matrix<double,0,1> v1, v2, v3;
|
|
v1 = randm(num_features(), 1, rnd);
|
|
v2 = randm(num_features(), 1, rnd);
|
|
v3 = randm(num_features(), 1, rnd);
|
|
v1(0) = 1;
|
|
v2(1) = 1;
|
|
v3(2) = 1;
|
|
v1(3) = -1;
|
|
v2(4) = -1;
|
|
v3(5) = -1;
|
|
for (unsigned long i = 0; i < num_features(); ++i)
|
|
{
|
|
if ( i < 3)
|
|
feats[i] = v1;
|
|
else if (i < 6)
|
|
feats[i] = v2;
|
|
else
|
|
feats[i] = v3;
|
|
}
|
|
}
|
|
|
|
unsigned long num_features() const { return 10; }
|
|
unsigned long window_size() const { return 3; }
|
|
|
|
template <typename feature_setter>
|
|
void get_features (
|
|
feature_setter& set_feature,
|
|
const sequence_type& x,
|
|
unsigned long position
|
|
) const
|
|
{
|
|
const matrix<double,0,1>& m = feats.find(x[position])->second;
|
|
for (unsigned long i = 0; i < num_features(); ++i)
|
|
{
|
|
set_feature(i, m(i));
|
|
}
|
|
}
|
|
|
|
};
|
|
|
|
template <bool use_BIO_model_, bool use_high_order_features_, bool neg>
|
|
void serialize(const unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item , std::ostream& out )
|
|
{
|
|
serialize(item.feats, out);
|
|
}
|
|
|
|
template <bool use_BIO_model_, bool use_high_order_features_, bool neg>
|
|
void deserialize(unigram_extractor<use_BIO_model_,use_high_order_features_,neg>& item, std::istream& in)
|
|
{
|
|
deserialize(item.feats, in);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
void make_dataset (
|
|
std::vector<std::vector<unsigned long> >& samples,
|
|
std::vector<std::vector<unsigned long> >& labels,
|
|
unsigned long dataset_size
|
|
)
|
|
{
|
|
samples.clear();
|
|
labels.clear();
|
|
|
|
samples.resize(dataset_size);
|
|
labels.resize(dataset_size);
|
|
|
|
|
|
unigram_extractor<true,true,true> fe;
|
|
dlib::rand rnd;
|
|
|
|
for (unsigned long iter = 0; iter < dataset_size; ++iter)
|
|
{
|
|
|
|
samples[iter].resize(10);
|
|
labels[iter].resize(10);
|
|
|
|
for (unsigned long i = 0; i < samples[iter].size(); ++i)
|
|
{
|
|
samples[iter][i] = rnd.get_random_32bit_number()%fe.num_features();
|
|
if (samples[iter][i] < 3)
|
|
{
|
|
labels[iter][i] = impl_ss::BEGIN;
|
|
}
|
|
else if (samples[iter][i] < 6)
|
|
{
|
|
labels[iter][i] = impl_ss::INSIDE;
|
|
}
|
|
else
|
|
{
|
|
labels[iter][i] = impl_ss::OUTSIDE;
|
|
}
|
|
|
|
if (i != 0)
|
|
{
|
|
// do rejection sampling to avoid impossible labels
|
|
if (labels[iter][i] == impl_ss::INSIDE &&
|
|
labels[iter][i-1] == impl_ss::OUTSIDE)
|
|
{
|
|
--i;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
void make_dataset2 (
|
|
std::vector<std::vector<unsigned long> >& samples,
|
|
std::vector<std::vector<std::pair<unsigned long, unsigned long> > >& segments,
|
|
unsigned long dataset_size
|
|
)
|
|
{
|
|
segments.clear();
|
|
std::vector<std::vector<unsigned long> > labels;
|
|
make_dataset(samples, labels, dataset_size);
|
|
segments.resize(samples.size());
|
|
|
|
// Convert from BIO tagging to the explicit segments representation.
|
|
for (unsigned long k = 0; k < labels.size(); ++k)
|
|
{
|
|
for (unsigned long i = 0; i < labels[k].size(); ++i)
|
|
{
|
|
if (labels[k][i] == impl_ss::BEGIN)
|
|
{
|
|
const unsigned long begin = i;
|
|
++i;
|
|
while (i < labels[k].size() && labels[k][i] == impl_ss::INSIDE)
|
|
++i;
|
|
|
|
segments[k].push_back(std::make_pair(begin, i));
|
|
--i;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <bool use_BIO_model, bool use_high_order_features, bool allow_negative_weights>
|
|
void do_test()
|
|
{
|
|
dlog << LINFO << "use_BIO_model: "<< use_BIO_model;
|
|
dlog << LINFO << "use_high_order_features: "<< use_high_order_features;
|
|
dlog << LINFO << "allow_negative_weights: "<< allow_negative_weights;
|
|
|
|
std::vector<std::vector<unsigned long> > samples;
|
|
std::vector<std::vector<std::pair<unsigned long,unsigned long> > > segments;
|
|
make_dataset2( samples, segments, 100);
|
|
|
|
print_spinner();
|
|
typedef unigram_extractor<use_BIO_model,use_high_order_features,allow_negative_weights> fe_type;
|
|
|
|
fe_type fe_temp;
|
|
fe_type fe_temp2;
|
|
structural_sequence_segmentation_trainer<fe_type> trainer(fe_temp2);
|
|
trainer.set_c(5);
|
|
trainer.set_num_threads(1);
|
|
|
|
|
|
sequence_segmenter<fe_type> labeler = trainer.train(samples, segments);
|
|
|
|
print_spinner();
|
|
|
|
const std::vector<std::pair<unsigned long, unsigned long> > predicted_labels = labeler(samples[1]);
|
|
const std::vector<std::pair<unsigned long, unsigned long> > true_labels = segments[1];
|
|
/*
|
|
for (unsigned long i = 0; i < predicted_labels.size(); ++i)
|
|
cout << "["<<predicted_labels[i].first<<","<<predicted_labels[i].second<<") ";
|
|
cout << endl;
|
|
for (unsigned long i = 0; i < true_labels.size(); ++i)
|
|
cout << "["<<true_labels[i].first<<","<<true_labels[i].second<<") ";
|
|
cout << endl;
|
|
*/
|
|
|
|
DLIB_TEST(predicted_labels.size() > 0);
|
|
DLIB_TEST(predicted_labels.size() == true_labels.size());
|
|
for (unsigned long i = 0; i < predicted_labels.size(); ++i)
|
|
{
|
|
DLIB_TEST(predicted_labels[i].first == true_labels[i].first);
|
|
DLIB_TEST(predicted_labels[i].second == true_labels[i].second);
|
|
}
|
|
|
|
|
|
matrix<double> res;
|
|
|
|
res = cross_validate_sequence_segmenter(trainer, samples, segments, 3);
|
|
dlog << LINFO << "cv res: "<< res;
|
|
DLIB_TEST(min(res) > 0.98);
|
|
make_dataset2( samples, segments, 100);
|
|
res = test_sequence_segmenter(labeler, samples, segments);
|
|
dlog << LINFO << "test res: "<< res;
|
|
DLIB_TEST(min(res) > 0.98);
|
|
|
|
print_spinner();
|
|
|
|
ostringstream sout;
|
|
serialize(labeler, sout);
|
|
istringstream sin(sout.str());
|
|
sequence_segmenter<fe_type> labeler2;
|
|
deserialize(labeler2, sin);
|
|
|
|
res = test_sequence_segmenter(labeler2, samples, segments);
|
|
dlog << LINFO << "test res2: "<< res;
|
|
DLIB_TEST(min(res) > 0.98);
|
|
|
|
long N;
|
|
if (use_BIO_model)
|
|
N = 3*3+3;
|
|
else
|
|
N = 5*5+5;
|
|
const double min_normal_weight = min(colm(labeler2.get_weights(), 0, labeler2.get_weights().size()-N));
|
|
const double min_trans_weight = min(labeler2.get_weights());
|
|
dlog << LINFO << "min_normal_weight: " << min_normal_weight;
|
|
dlog << LINFO << "min_trans_weight: " << min_trans_weight;
|
|
if (allow_negative_weights)
|
|
{
|
|
DLIB_TEST(min_normal_weight < 0);
|
|
DLIB_TEST(min_trans_weight < 0);
|
|
}
|
|
else
|
|
{
|
|
DLIB_TEST(min_normal_weight == 0);
|
|
DLIB_TEST(min_trans_weight < 0);
|
|
}
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
|
|
class unit_test_sequence_segmenter : public tester
|
|
{
|
|
public:
|
|
unit_test_sequence_segmenter (
|
|
) :
|
|
tester ("test_sequence_segmenter",
|
|
"Runs tests on the sequence segmenting code.")
|
|
{}
|
|
|
|
void perform_test (
|
|
)
|
|
{
|
|
do_test<true,true,false>();
|
|
do_test<true,false,false>();
|
|
do_test<false,true,false>();
|
|
do_test<false,false,false>();
|
|
do_test<true,true,true>();
|
|
do_test<true,false,true>();
|
|
do_test<false,true,true>();
|
|
do_test<false,false,true>();
|
|
}
|
|
} a;
|
|
|
|
}
|
|
|
|
|
|
|