sustaining_gazes/lib/3rdParty/dlib/include/dlib/svm/roc_trainer.h
2016-04-28 15:40:36 -04:00

149 lines
5.1 KiB
C++

// Copyright (C) 2009 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#ifndef DLIB_ROC_TRAINEr_H_
#define DLIB_ROC_TRAINEr_H_
#include "roc_trainer_abstract.h"
#include "../algs.h"
#include <algorithm>
#include <limits>
#include <vector>
namespace dlib
{
// ----------------------------------------------------------------------------------------
template <
    typename trainer_type
    >
class roc_trainer_type
{
    /*!
        WHAT THIS OBJECT REPRESENTS
            A wrapper around another binary classification trainer.  train() runs the
            wrapped trainer and then replaces the bias term (b) of the resulting
            decision function.  The new bias is chosen from the scores of the training
            samples labeled class_selection (+1 or -1).  It is picked so that roughly
            desired_accuracy of those samples end up on their own side of the decision
            boundary.  Samples of the other class are not used when choosing the bias.
    !*/
public:
    typedef typename trainer_type::kernel_type kernel_type;
    typedef typename trainer_type::scalar_type scalar_type;
    typedef typename trainer_type::sample_type sample_type;
    typedef typename trainer_type::mem_manager_type mem_manager_type;
    typedef typename trainer_type::trained_function_type trained_function_type;

    roc_trainer_type (
    ) : desired_accuracy(0), class_selection(0){}
    /*!
        ensures
            - desired_accuracy == 0 and class_selection == 0.
              NOTE: class_selection == 0 matches neither the +1 nor the -1 label.
              train() on such an object therefore finds no samples to calibrate
              against, and returns the wrapped trainer's function with its own bias.
    !*/

    roc_trainer_type (
        const trainer_type& trainer_,
        const scalar_type& desired_accuracy_,
        const scalar_type& class_selection_
    ) : trainer(trainer_), desired_accuracy(desired_accuracy_), class_selection(class_selection_)
    /*!
        requires
            - 0 <= desired_accuracy_ <= 1
            - class_selection_ == -1 or class_selection_ == +1
    !*/
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(0 <= desired_accuracy && desired_accuracy <= 1 &&
                    (class_selection == -1 || class_selection == +1),
                    "\t roc_trainer_type::roc_trainer_type()"
                    << "\n\t invalid inputs were given to this function"
                    << "\n\t desired_accuracy: " << desired_accuracy
                    << "\n\t class_selection: " << class_selection
        );
    }

    template <
        typename in_sample_vector_type,
        typename in_scalar_vector_type
        >
    const trained_function_type train (
        const in_sample_vector_type& samples,
        const in_scalar_vector_type& labels
    ) const
    /*!
        requires
            - is_binary_classification_problem(samples, labels) == true
        ensures
            - trains the wrapped trainer on (samples, labels) and returns its decision
              function with the bias re-chosen as described above.
    !*/
    {
        // make sure requires clause is not broken
        DLIB_ASSERT(is_binary_classification_problem(samples, labels),
                    "\t roc_trainer_type::train()"
                    << "\n\t invalid inputs were given to this function"
        );
        return do_train(mat(samples), mat(labels));
    }

private:

    template <
        typename in_sample_vector_type,
        typename in_scalar_vector_type
        >
    const trained_function_type do_train (
        const in_sample_vector_type& samples,
        const in_scalar_vector_type& labels
    ) const
    {
        trained_function_type df = trainer.train(samples, labels);

        // Keep the trainer's own bias so it can be restored when there is nothing to
        // calibrate against.  Then clear it so df(x) yields the raw, unbiased score.
        const scalar_type original_bias = df.b;
        df.b = 0;

        // obtain all the scores from the df using all the class_selection labeled samples
        std::vector<double> scores;
        for (long i = 0; i < samples.size(); ++i)
        {
            if (labels(i) == class_selection)
                scores.push_back(df(samples(i)));
        }

        // No sample carries the selected label.  This happens if the requires clause
        // of train() was violated while DLIB_ASSERT is disabled, or if the object was
        // default constructed.  Indexing scores below would then read out of bounds,
        // so return the trainer's function with its original bias.
        if (scores.empty())
        {
            df.b = original_bias;
            return df;
        }

        // Order the scores so the selected class's most confident samples come first.
        if (class_selection == +1)
            std::sort(scores.rbegin(), scores.rend());
        else
            std::sort(scores.begin(), scores.end());

        // now pick out the index that gives us the desired accuracy with regards to selected class
        unsigned long idx = static_cast<unsigned long>(desired_accuracy*scores.size() + 0.5);
        if (idx >= scores.size())
            idx = scores.size()-1;
        df.b = scores[idx];

        // In this case add a very small extra amount to the bias so that all the samples
        // with the class_selection label are classified correctly.  A sample whose
        // score equals the bias would otherwise evaluate to exactly 0.
        if (desired_accuracy == 1)
        {
            // Scale the nudge by |b| so it points in the intended direction even
            // when b is negative.  Use 1 when b is exactly 0, because a relative
            // nudge of 0 would change nothing.
            const scalar_type magnitude = (df.b < 0) ? scalar_type(-df.b)
                                        : (df.b > 0) ? scalar_type(df.b)
                                                     : scalar_type(1);
            const scalar_type nudge = std::numeric_limits<scalar_type>::epsilon()*magnitude;
            if (class_selection == +1)
                df.b -= nudge;
            else
                df.b += nudge;
        }
        return df;
    }

    trainer_type trainer;
    scalar_type desired_accuracy;  // fraction in [0,1] (checked by the constructor's DLIB_ASSERT)
    scalar_type class_selection;   // +1 or -1 (checked by the constructor's DLIB_ASSERT); 0 when default constructed
};
// ----------------------------------------------------------------------------------------
template <typename trainer_type>
const roc_trainer_type<trainer_type> roc_c1_trainer (
    const trainer_type& trainer,
    const typename trainer_type::scalar_type& desired_accuracy
)
/*!
    Convenience factory.  Returns a roc_trainer_type that wraps trainer and
    calibrates the bias against the training samples labeled +1.
!*/
{
    typedef typename trainer_type::scalar_type label_type;
    const label_type selected_class = +1;
    return roc_trainer_type<trainer_type>(trainer, desired_accuracy, selected_class);
}
// ----------------------------------------------------------------------------------------
template <typename trainer_type>
const roc_trainer_type<trainer_type> roc_c2_trainer (
    const trainer_type& trainer,
    const typename trainer_type::scalar_type& desired_accuracy
)
/*!
    Convenience factory.  Returns a roc_trainer_type that wraps trainer and
    calibrates the bias against the training samples labeled -1.
!*/
{
    typedef typename trainer_type::scalar_type label_type;
    const label_type selected_class = -1;
    return roc_trainer_type<trainer_type>(trainer, desired_accuracy, selected_class);
}
// ----------------------------------------------------------------------------------------
}
#endif // DLIB_ROC_TRAINEr_H_