// Origin: sustaining_gazes/lib/3rdParty/dlib/include/dlib/test/bsp.cpp
// Snapshot: 2016-04-28 15:40:36 -04:00 (566 lines, 16 KiB, C++)

// Copyright (C) 2012 Davis E. King (davis@dlib.net)
// License: Boost Software License See LICENSE.txt for the full license.
#include <dlib/bsp.h>
#include <dlib/threads.h>
#include <dlib/pipe.h>
#include <dlib/matrix.h>
#include "tester.h"
namespace
{
// Pull the test harness, dlib and standard library names into this
// file-local (anonymous) namespace so the tests below can use them unqualified.
using namespace test;
using namespace dlib;
using namespace std;
// Logger shared by every test in this file; messages are tagged "test.bsp".
logger dlog("test.bsp");
template <typename funct>
struct callfunct_helper
{
callfunct_helper (
funct f_,
int port_,
bool& error_occurred_
) :f(f_), port(port_), error_occurred(error_occurred_) {}
funct f;
int port;
bool& error_occurred;
void operator() (
) const
{
try
{
bsp_listen(port, f);
}
catch (exception& e)
{
dlog << LERROR << "error calling bsp_listen(): " << e.what();
error_occurred = true;
}
}
};
// Factory for callfunct_helper that deduces the job type from its argument.
template <typename funct>
callfunct_helper<funct> callfunct(funct f, int port, bool& error_occurred)
{
    callfunct_helper<funct> runner(f, port, error_occurred);
    return runner;
}
// ----------------------------------------------------------------------------------------
template <typename funct>
struct callfunct_helper_pn
{
callfunct_helper_pn (
funct f_,
int port_,
bool& error_occurred_,
dlib::pipe<unsigned short>& port_pipe_
) :f(f_), port(port_), error_occurred(error_occurred_), port_pipe(port_pipe_) {}
funct f;
int port;
bool& error_occurred;
dlib::pipe<unsigned short>& port_pipe;
struct helper
{
helper (
dlib::pipe<unsigned short>& port_pipe_
) : port_pipe(port_pipe_) {}
dlib::pipe<unsigned short>& port_pipe;
void operator() (unsigned short p) { port_pipe.enqueue(p); }
};
void operator() (
) const
{
try
{
bsp_listen_dynamic_port(port, helper(port_pipe), f);
}
catch (exception& e)
{
dlog << LERROR << "error calling bsp_listen_dynamic_port(): " << e.what();
error_occurred = true;
}
}
};
// Factory for callfunct_helper_pn that deduces the job type from its argument.
template <typename funct>
callfunct_helper_pn<funct> callfunct(funct f, int port, bool& error_occurred, dlib::pipe<unsigned short>& port_pipe)
{
    callfunct_helper_pn<funct> runner(f, port, error_occurred, port_pipe);
    return runner;
}
// ----------------------------------------------------------------------------------------
// Driver side of the array-sum test: broadcasts v through the context, then
// adds up every int received back until try_receive() returns false.  The
// total is written to result.
void sum_array_driver (
    bsp_context& obj,
    const std::vector<int>& v,
    int& result
)
{
    obj.broadcast(v);

    result = 0;
    for (int partial_sum = 0; obj.try_receive(partial_sum); )
        result += partial_sum;
}
// Worker side of the array-sum test: receives a vector of ints, sums its
// elements and sends the sum to node 0.
void sum_array_other (
    bsp_context& obj
)
{
    std::vector<int> values;
    obj.receive(values);

    int total = 0;
    for (std::vector<int>::const_iterator it = values.begin(); it != values.end(); ++it)
        total += *it;

    obj.send(total, 0);
}
void dotest1()
{
dlog << LINFO << "start dotest1()";
print_spinner();
bool error_occurred = false;
{
thread_function t1(callfunct(sum_array_other, 12345, error_occurred));
thread_function t2(callfunct(sum_array_other, 12346, error_occurred));
thread_function t3(callfunct(sum_array_other, 12347, error_occurred));
std::vector<int> v;
int true_value = 0;
for (int i = 0; i < 10; ++i)
{
v.push_back(i);
true_value += i;
}
// wait a little bit for the threads to start up
dlib::sleep(200);
try
{
int result;
std::vector<network_address> hosts;
hosts.push_back("127.0.0.1:12345");
hosts.push_back("localhost:12346");
hosts.push_back("127.0.0.1:12347");
bsp_connect(hosts, sum_array_driver, dlib::ref(v), dlib::ref(result));
dlog << LINFO << "result: "<< result;
dlog << LINFO << "should be: "<< 3*true_value;
DLIB_TEST(result == 3*true_value);
}
catch (std::exception& e)
{
dlog << LERROR << "error during bsp_context: " << e.what();
DLIB_TEST(false);
}
}
DLIB_TEST(error_occurred == false);
}
// ----------------------------------------------------------------------------------------
// Job for dotest2(): the node whose id equals the template argument sleeps for
// 100ms, every other node returns immediately.
template <unsigned long id>
void test2_job(bsp_context& obj)
{
    const bool is_delayed_node = (obj.node_id() == id);
    if (is_delayed_node)
        dlib::sleep(100);
}
template <unsigned long id>
void dotest2()
{
dlog << LINFO << "start dotest2()";
print_spinner();
bool error_occurred = false;
{
thread_function t1(callfunct(test2_job<id>, 12345, error_occurred));
thread_function t2(callfunct(test2_job<id>, 12346, error_occurred));
thread_function t3(callfunct(test2_job<id>, 12347, error_occurred));
// wait a little bit for the threads to start up
dlib::sleep(200);
try
{
std::vector<network_address> hosts;
hosts.push_back("127.0.0.1:12345");
hosts.push_back("127.0.0.1:12346");
hosts.push_back("127.0.0.1:12347");
bsp_connect(hosts, test2_job<id>);
}
catch (std::exception& e)
{
dlog << LERROR << "error during bsp_context: " << e.what();
DLIB_TEST(false);
}
}
DLIB_TEST(error_occurred == false);
}
// ----------------------------------------------------------------------------------------
// Job run by every node in dotest3().  Each node broadcasts its id and sums
// what it receives, every node except node 1 forwards its sum to node 1, and
// the totals are then passed around the ring of nodes.  Node 0 stores its
// final total in result; other nodes leave result untouched.
void test3_job_driver(bsp_context& obj, int& result)
{
    obj.broadcast(obj.node_id());

    int running_total = 0;
    int incoming = 0;
    while (obj.try_receive(incoming))
        running_total += incoming;

    // everyone but node 1 hands its total to node 1, which adds them all up
    if (obj.node_id() != 1)
        obj.send(running_total, 1);
    while (obj.try_receive(incoming))
        running_total += incoming;

    // Pass the totals three steps along the ring of nodes so that the value
    // held by node 1 ends up at node 0.
    for (int hop = 0; hop < 3; ++hop)
    {
        obj.send(running_total, (obj.node_id()+1)%obj.number_of_nodes());
        obj.receive(running_total);
    }

    // Extra traffic only: each round passes the values four more steps along
    // the ring, which (with the four nodes dotest3() sets up) brings every
    // value back to the node it started the round on.
    for (int round = 0; round < 100; ++round)
    {
        dlog << LINFO << "k: " << round;
        for (int hop = 0; hop < 4; ++hop)
        {
            obj.send(running_total, (obj.node_id()+1)%obj.number_of_nodes());
            obj.receive(running_total);
        }
    }

    dlog << LINFO << "TERMINATE";

    if (obj.node_id() == 0)
        result = running_total;
}
// Listener-side entry point for dotest3(): runs the same job as the driver but
// throws the result away.
void test3_job(bsp_context& obj)
{
    int discarded = 0;
    test3_job_driver(obj, discarded);
}
void dotest3()
{
dlog << LINFO << "start dotest3()";
print_spinner();
bool error_occurred = false;
{
dlib::pipe<unsigned short> ports(5);
thread_function t1(callfunct(test3_job, 12345, error_occurred, ports));
thread_function t2(callfunct(test3_job, 0, error_occurred, ports));
thread_function t3(callfunct(test3_job, 12347, error_occurred, ports));
try
{
std::vector<network_address> hosts;
unsigned short port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
int result = 0;
const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2;
bsp_connect(hosts, test3_job_driver, dlib::ref(result));
dlog << LINFO << "result: " << result;
dlog << LINFO << "should be: " << expected;
DLIB_TEST(result == expected);
}
catch (std::exception& e)
{
dlog << LERROR << "error during bsp_context: " << e.what();
DLIB_TEST(false);
}
}
DLIB_TEST(error_occurred == false);
}
// ----------------------------------------------------------------------------------------
// Job run by every node in dotest4().  Same id-summing and ring-passing scheme
// as test3_job_driver(), except that the extra-traffic loop runs 40 rounds and
// follows every hop with an argument-less obj.receive().  Node 0 stores its
// final total in result; other nodes leave result untouched.
void test4_job_driver(bsp_context& obj, int& result)
{
    obj.broadcast(obj.node_id());

    int running_total = 0;
    int incoming = 0;
    while (obj.try_receive(incoming))
        running_total += incoming;

    // everyone but node 1 hands its total to node 1, which adds them all up
    if (obj.node_id() != 1)
        obj.send(running_total, 1);
    while (obj.try_receive(incoming))
        running_total += incoming;

    // Pass the totals three steps along the ring of nodes so that the value
    // held by node 1 ends up at node 0.
    for (int hop = 0; hop < 3; ++hop)
    {
        obj.send(running_total, (obj.node_id()+1)%obj.number_of_nodes());
        obj.receive(running_total);
    }

    // Extra traffic only: each round passes the values four more steps along
    // the ring, which (with the four nodes dotest4() sets up) brings every
    // value back to the node it started the round on.
    for (int round = 0; round < 40; ++round)
    {
        dlog << LINFO << "k: " << round;
        for (int hop = 0; hop < 4; ++hop)
        {
            obj.send(running_total, (obj.node_id()+1)%obj.number_of_nodes());
            obj.receive(running_total);
            // NOTE(review): receive() with no argument presumably acts as a
            // synchronization point between nodes -- confirm against the
            // bsp_context documentation.
            obj.receive();
        }
    }

    dlog << LINFO << "TERMINATE";

    if (obj.node_id() == 0)
        result = running_total;
}
// Listener-side entry point for dotest4(): runs the same job as the driver but
// throws the result away.
void test4_job(bsp_context& obj)
{
    int discarded = 0;
    test4_job_driver(obj, discarded);
}
void dotest4()
{
dlog << LINFO << "start dotest4()";
print_spinner();
bool error_occurred = false;
{
dlib::pipe<unsigned short> ports(5);
thread_function t1(callfunct(test4_job, 0, error_occurred, ports));
thread_function t2(callfunct(test4_job, 0, error_occurred, ports));
thread_function t3(callfunct(test4_job, 0, error_occurred, ports));
try
{
std::vector<network_address> hosts;
unsigned short port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
int result = 0;
const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2;
bsp_connect(hosts, test4_job_driver, dlib::ref(result));
dlog << LINFO << "result: " << result;
dlog << LINFO << "should be: " << expected;
DLIB_TEST(result == expected);
}
catch (std::exception& e)
{
dlog << LERROR << "error during bsp_context: " << e.what();
DLIB_TEST(false);
}
}
DLIB_TEST(error_occurred == false);
}
// ----------------------------------------------------------------------------------------
// Job for dotest5(): ignores the context and stores the marker value 25 in
// val so the caller can tell that the job ran.
void test5_job(
    bsp_context& /*unused*/,
    int& val
)
{
    const int marker = 25;
    val = marker;
}
// Calls bsp_connect() with an empty host list and checks that the job still
// ran, i.e. that test5_job() stored 25 in val.
void dotest5()
{
    dlog << LINFO << "start dotest5()";
    print_spinner();

    int val = 0;
    std::vector<network_address> no_hosts;
    bsp_connect(no_hosts, test5_job, dlib::ref(val));

    DLIB_TEST(val == 25);
}
// ----------------------------------------------------------------------------------------
// Objective function minimized by the grid search test: (x-2)^2, which is
// zero at x == 2.
double f ( double x)
{
    const double offset = x - 2.0;
    return std::pow(offset, 2.0);
}
// Node 0 side of a BSP grid search for the minimum of f().
//
// Each iteration broadcasts the current search interval [left, right], then
// collects one (x, f(x)) pair from every other node and keeps the best one
// seen so far.  The interval is then halved and re-centered on the best x.
//
// Outputs:
//   min_value - smallest reported function value (infinity if none was reported)
//   optimal_x - x at which min_value was reported; the midpoint of the initial
//               interval if no node ever reported a value below infinity
void bsp_job_node_0 (
    bsp_context& context,
    double& min_value,
    double& optimal_x
)
{
    double left = -100;
    double right = 100;

    min_value = std::numeric_limits<double>::infinity();
    // optimal_x is read when the interval is re-centered below, but it is only
    // assigned when some node reports an improvement.  Give it a defined
    // starting value so the result never depends on what the caller passed in.
    optimal_x = (left + right)/2;
    double interval_width = std::abs(right-left);

    // This is doing a BSP based grid search for the minimum of f().  Here we
    // do 100 iterations where we keep shrinking the grid size.
    for (int i = 0; i < 100; ++i)
    {
        context.broadcast(left);
        context.broadcast(right);

        // one report per node other than this one
        for (unsigned int k = 1; k < context.number_of_nodes(); ++k)
        {
            std::pair<double,double> val;   // (x, f(x))
            context.receive(val);
            if (val.second < min_value)
            {
                min_value = val.second;
                optimal_x = val.first;
            }
        }

        // halve the interval and center it on the best point found so far
        interval_width *= 0.5;
        left = optimal_x - interval_width/2;
        right = optimal_x + interval_width/2;
    }
}
// Worker side of the BSP grid search.  For every interval [left, right]
// received, the node evaluates f() at 100 evenly spaced points in its own
// slice of the interval and sends the best (x, f(x)) pair to node 0.  The loop
// ends when try_receive() returns false.
void bsp_job_other_nodes (
    bsp_context& context
)
{
    double left, right;
    while (context.try_receive(left))
    {
        context.receive(right);

        // Fractions of the interval where this node's slice starts and ends;
        // nodes 1..N-1 split the interval into N-1 equal slices.
        const double slice_begin = (context.node_id()-1)/(context.number_of_nodes()-1.0);
        const double slice_end = context.node_id() /(context.number_of_nodes()-1.0);
        const double width = right-left;

        matrix<double> candidates = linspace(left +slice_begin*width, left + slice_end*width, 100);

        double best_x = 0;
        double best_val = std::numeric_limits<double>::infinity();
        for (long idx = 0; idx < candidates.size(); ++idx)
        {
            const double value_here = f(candidates(idx));
            if (value_here < best_val)
            {
                best_val = value_here;
                best_x = candidates(idx);
            }
        }

        context.send(std::make_pair(best_x, best_val), 0);
    }
}
void dotest6()
{
dlog << LINFO << "start dotest6()";
print_spinner();
bool error_occurred = false;
{
dlib::pipe<unsigned short> ports(5);
thread_function t1(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
thread_function t2(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
thread_function t3(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
try
{
std::vector<network_address> hosts;
unsigned short port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
double min_value = 10, optimal_x = 0;
bsp_connect(hosts, bsp_job_node_0, dlib::ref(min_value), dlib::ref(optimal_x));
dlog << LINFO << "min_value: " << min_value;
dlog << LINFO << "optimal_x: " << optimal_x;
DLIB_TEST(std::abs(min_value - 0) < 1e-14);
DLIB_TEST(std::abs(optimal_x - 2) < 1e-14);
}
catch (std::exception& e)
{
dlog << LERROR << "error during bsp_context: " << e.what();
DLIB_TEST(false);
}
}
DLIB_TEST(error_occurred == false);
}
// ----------------------------------------------------------------------------------------
// Test suite object for the BSP tests.  The single file-local instance `a`
// declared after the class body is what makes the suite exist; the name and
// description are passed to the tester base class.
class bsp_tester : public tester
{
public:
    bsp_tester (
    ) :
        tester ("test_bsp",
                "Runs tests on the BSP components.")
    {}

    // Runs the complete set of BSP tests three times in a row.
    void perform_test (
    )
    {
        for (int pass = 0; pass < 3; ++pass)
        {
            dotest1();

            // one run per choice of delayed node
            dotest2<0>();
            dotest2<1>();
            dotest2<2>();

            dotest3();
            dotest4();
            dotest5();
            dotest6();
        }
    }
} a;
}