149 lines
5.1 KiB
C++
149 lines
5.1 KiB
C++
// Copyright (C) 2009 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
#ifndef DLIB_ROC_TRAINEr_H_
|
|
#define DLIB_ROC_TRAINEr_H_
|
|
|
|
#include "roc_trainer_abstract.h"
#include "../algs.h"
#include <algorithm>
#include <cmath>
#include <limits>
#include <vector>
|
|
|
|
namespace dlib
|
|
{
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename trainer_type
|
|
>
|
|
class roc_trainer_type
|
|
{
|
|
public:
|
|
typedef typename trainer_type::kernel_type kernel_type;
|
|
typedef typename trainer_type::scalar_type scalar_type;
|
|
typedef typename trainer_type::sample_type sample_type;
|
|
typedef typename trainer_type::mem_manager_type mem_manager_type;
|
|
typedef typename trainer_type::trained_function_type trained_function_type;
|
|
|
|
roc_trainer_type (
|
|
) : desired_accuracy(0), class_selection(0){}
|
|
|
|
roc_trainer_type (
|
|
const trainer_type& trainer_,
|
|
const scalar_type& desired_accuracy_,
|
|
const scalar_type& class_selection_
|
|
) : trainer(trainer_), desired_accuracy(desired_accuracy_), class_selection(class_selection_)
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(0 <= desired_accuracy && desired_accuracy <= 1 &&
|
|
(class_selection == -1 || class_selection == +1),
|
|
"\t roc_trainer_type::roc_trainer_type()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
<< "\n\t desired_accuracy: " << desired_accuracy
|
|
<< "\n\t class_selection: " << class_selection
|
|
);
|
|
}
|
|
|
|
template <
|
|
typename in_sample_vector_type,
|
|
typename in_scalar_vector_type
|
|
>
|
|
const trained_function_type train (
|
|
const in_sample_vector_type& samples,
|
|
const in_scalar_vector_type& labels
|
|
) const
|
|
/*!
|
|
requires
|
|
- is_binary_classification_problem(samples, labels) == true
|
|
!*/
|
|
{
|
|
// make sure requires clause is not broken
|
|
DLIB_ASSERT(is_binary_classification_problem(samples, labels),
|
|
"\t roc_trainer_type::train()"
|
|
<< "\n\t invalid inputs were given to this function"
|
|
);
|
|
|
|
|
|
return do_train(mat(samples), mat(labels));
|
|
}
|
|
|
|
private:
|
|
|
|
template <
|
|
typename in_sample_vector_type,
|
|
typename in_scalar_vector_type
|
|
>
|
|
const trained_function_type do_train (
|
|
const in_sample_vector_type& samples,
|
|
const in_scalar_vector_type& labels
|
|
) const
|
|
{
|
|
trained_function_type df = trainer.train(samples, labels);
|
|
|
|
// clear out the old bias
|
|
df.b = 0;
|
|
|
|
// obtain all the scores from the df using all the class_selection labeled samples
|
|
std::vector<double> scores;
|
|
for (long i = 0; i < samples.size(); ++i)
|
|
{
|
|
if (labels(i) == class_selection)
|
|
scores.push_back(df(samples(i)));
|
|
}
|
|
|
|
if (class_selection == +1)
|
|
std::sort(scores.rbegin(), scores.rend());
|
|
else
|
|
std::sort(scores.begin(), scores.end());
|
|
|
|
// now pick out the index that gives us the desired accuracy with regards to selected class
|
|
unsigned long idx = static_cast<unsigned long>(desired_accuracy*scores.size() + 0.5);
|
|
if (idx >= scores.size())
|
|
idx = scores.size()-1;
|
|
|
|
df.b = scores[idx];
|
|
|
|
// In this case add a very small extra amount to the bias so that all the samples
|
|
// with the class_selection label are classified correctly.
|
|
if (desired_accuracy == 1)
|
|
{
|
|
if (class_selection == +1)
|
|
df.b -= std::numeric_limits<scalar_type>::epsilon()*df.b;
|
|
else
|
|
df.b += std::numeric_limits<scalar_type>::epsilon()*df.b;
|
|
}
|
|
|
|
return df;
|
|
}
|
|
|
|
trainer_type trainer;
|
|
scalar_type desired_accuracy;
|
|
scalar_type class_selection;
|
|
};
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename trainer_type
|
|
>
|
|
const roc_trainer_type<trainer_type> roc_c1_trainer (
|
|
const trainer_type& trainer,
|
|
const typename trainer_type::scalar_type& desired_accuracy
|
|
) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, +1); }
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <
|
|
typename trainer_type
|
|
>
|
|
const roc_trainer_type<trainer_type> roc_c2_trainer (
|
|
const trainer_type& trainer,
|
|
const typename trainer_type::scalar_type& desired_accuracy
|
|
) { return roc_trainer_type<trainer_type>(trainer, desired_accuracy, -1); }
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
}
|
|
|
|
#endif // DLIB_ROC_TRAINEr_H_
|
|
|
|
|