566 lines
16 KiB
C++
566 lines
16 KiB
C++
// Copyright (C) 2012 Davis E. King (davis@dlib.net)
|
|
// License: Boost Software License See LICENSE.txt for the full license.
|
|
|
|
|
|
#include <dlib/bsp.h>
|
|
#include <dlib/threads.h>
|
|
#include <dlib/pipe.h>
|
|
#include <dlib/matrix.h>
|
|
|
|
#include "tester.h"
|
|
|
|
namespace
|
|
{
|
|
|
|
using namespace test;
|
|
using namespace dlib;
|
|
using namespace std;
|
|
|
|
// Logger shared by every test in this file; its messages carry the name "test.bsp".
logger dlog("test.bsp");
|
|
|
|
|
|
template <typename funct>
|
|
struct callfunct_helper
|
|
{
|
|
callfunct_helper (
|
|
funct f_,
|
|
int port_,
|
|
bool& error_occurred_
|
|
) :f(f_), port(port_), error_occurred(error_occurred_) {}
|
|
|
|
funct f;
|
|
int port;
|
|
bool& error_occurred;
|
|
|
|
void operator() (
|
|
) const
|
|
{
|
|
try
|
|
{
|
|
bsp_listen(port, f);
|
|
}
|
|
catch (exception& e)
|
|
{
|
|
dlog << LERROR << "error calling bsp_listen(): " << e.what();
|
|
error_occurred = true;
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename funct>
|
|
callfunct_helper<funct> callfunct(funct f, int port, bool& error_occurred)
|
|
{
|
|
return callfunct_helper<funct>(f,port,error_occurred);
|
|
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <typename funct>
|
|
struct callfunct_helper_pn
|
|
{
|
|
callfunct_helper_pn (
|
|
funct f_,
|
|
int port_,
|
|
bool& error_occurred_,
|
|
dlib::pipe<unsigned short>& port_pipe_
|
|
) :f(f_), port(port_), error_occurred(error_occurred_), port_pipe(port_pipe_) {}
|
|
|
|
funct f;
|
|
int port;
|
|
bool& error_occurred;
|
|
dlib::pipe<unsigned short>& port_pipe;
|
|
|
|
struct helper
|
|
{
|
|
helper (
|
|
dlib::pipe<unsigned short>& port_pipe_
|
|
) : port_pipe(port_pipe_) {}
|
|
|
|
dlib::pipe<unsigned short>& port_pipe;
|
|
|
|
void operator() (unsigned short p) { port_pipe.enqueue(p); }
|
|
};
|
|
|
|
void operator() (
|
|
) const
|
|
{
|
|
try
|
|
{
|
|
bsp_listen_dynamic_port(port, helper(port_pipe), f);
|
|
}
|
|
catch (exception& e)
|
|
{
|
|
dlog << LERROR << "error calling bsp_listen_dynamic_port(): " << e.what();
|
|
error_occurred = true;
|
|
}
|
|
}
|
|
};
|
|
|
|
template <typename funct>
|
|
callfunct_helper_pn<funct> callfunct(funct f, int port, bool& error_occurred, dlib::pipe<unsigned short>& port_pipe)
|
|
{
|
|
return callfunct_helper_pn<funct>(f,port,error_occurred,port_pipe);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
// Node-0 job for dotest1(): broadcasts v to every other node, then adds up every
// int it receives back into result (result is reset to 0 first).
void sum_array_driver (
    bsp_context& obj,
    const std::vector<int>& v,
    int& result
)
{
    obj.broadcast(v);

    result = 0;
    for (int partial; obj.try_receive(partial); )
        result += partial;
}
|
|
|
|
// Listener job for dotest1(): receives a vector<int>, sums its elements and
// sends the sum to node 0.
void sum_array_other (
    bsp_context& obj
)
{
    std::vector<int> data;
    obj.receive(data);

    int local_sum = 0;
    for (std::vector<int>::const_iterator it = data.begin(); it != data.end(); ++it)
        local_sum += *it;

    obj.send(local_sum, 0);
}
|
|
|
|
|
|
void dotest1()
|
|
{
|
|
dlog << LINFO << "start dotest1()";
|
|
print_spinner();
|
|
bool error_occurred = false;
|
|
{
|
|
thread_function t1(callfunct(sum_array_other, 12345, error_occurred));
|
|
thread_function t2(callfunct(sum_array_other, 12346, error_occurred));
|
|
thread_function t3(callfunct(sum_array_other, 12347, error_occurred));
|
|
std::vector<int> v;
|
|
int true_value = 0;
|
|
for (int i = 0; i < 10; ++i)
|
|
{
|
|
v.push_back(i);
|
|
true_value += i;
|
|
}
|
|
|
|
// wait a little bit for the threads to start up
|
|
dlib::sleep(200);
|
|
|
|
try
|
|
{
|
|
int result;
|
|
std::vector<network_address> hosts;
|
|
hosts.push_back("127.0.0.1:12345");
|
|
hosts.push_back("localhost:12346");
|
|
hosts.push_back("127.0.0.1:12347");
|
|
bsp_connect(hosts, sum_array_driver, dlib::ref(v), dlib::ref(result));
|
|
|
|
dlog << LINFO << "result: "<< result;
|
|
dlog << LINFO << "should be: "<< 3*true_value;
|
|
DLIB_TEST(result == 3*true_value);
|
|
}
|
|
catch (std::exception& e)
|
|
{
|
|
dlog << LERROR << "error during bsp_context: " << e.what();
|
|
DLIB_TEST(false);
|
|
}
|
|
}
|
|
DLIB_TEST(error_occurred == false);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
template <unsigned long id>
|
|
void test2_job(bsp_context& obj)
|
|
{
|
|
if (obj.node_id() == id)
|
|
dlib::sleep(100);
|
|
}
|
|
|
|
template <unsigned long id>
|
|
void dotest2()
|
|
{
|
|
dlog << LINFO << "start dotest2()";
|
|
print_spinner();
|
|
bool error_occurred = false;
|
|
{
|
|
thread_function t1(callfunct(test2_job<id>, 12345, error_occurred));
|
|
thread_function t2(callfunct(test2_job<id>, 12346, error_occurred));
|
|
thread_function t3(callfunct(test2_job<id>, 12347, error_occurred));
|
|
|
|
// wait a little bit for the threads to start up
|
|
dlib::sleep(200);
|
|
|
|
try
|
|
{
|
|
std::vector<network_address> hosts;
|
|
hosts.push_back("127.0.0.1:12345");
|
|
hosts.push_back("127.0.0.1:12346");
|
|
hosts.push_back("127.0.0.1:12347");
|
|
bsp_connect(hosts, test2_job<id>);
|
|
}
|
|
catch (std::exception& e)
|
|
{
|
|
dlog << LERROR << "error during bsp_context: " << e.what();
|
|
DLIB_TEST(false);
|
|
}
|
|
|
|
}
|
|
DLIB_TEST(error_occurred == false);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
// Message-passing exercise run on every node.
// 1. Every node broadcasts its id.
// 2. Each node sums the ids it receives.
// 3. Every node except node 1 sends that sum to node 1.
// 4. Node 1 adds the sums it receives to its own.
// 5. Values then hop around the ring node_id -> node_id+1; after three hops node
//    1's total has reached node 0, which stores it in result (other nodes leave
//    result untouched).
void test3_job_driver(bsp_context& obj, int& result)
{
    const unsigned long successor = (obj.node_id()+1)%obj.number_of_nodes();

    obj.broadcast(obj.node_id());

    int accum = 0;
    for (int incoming = 0; obj.try_receive(incoming); )
        accum += incoming;

    // send to node 1 so it can sum everything
    if (obj.node_id() != 1)
        obj.send(accum, 1);

    for (int incoming = 0; obj.try_receive(incoming); )
        accum += incoming;

    // Three hops carry the value held by node 1 over to node 0.
    for (int hop = 0; hop < 3; ++hop)
    {
        obj.send(accum, successor);
        obj.receive(accum);
    }

    // Each round makes four hops around the ring. dotest3() runs four nodes, so
    // every value ends a round where it started and accum is unchanged; the loop
    // only generates traffic.
    for (int k = 0; k < 100; ++k)
    {
        dlog << LINFO << "k: " << k;
        for (int hop = 0; hop < 4; ++hop)
        {
            obj.send(accum, successor);
            obj.receive(accum);
        }
    }

    dlog << LINFO << "TERMINATE";
    if (obj.node_id() == 0)
        result = accum;
}
|
|
|
|
|
|
// Listener-side wrapper: runs test3_job_driver() and discards its result, which
// only node 0 assigns.
void test3_job(bsp_context& obj)
{
    int ignored_result = 0;
    test3_job_driver(obj, ignored_result);
}
|
|
|
|
|
|
void dotest3()
|
|
{
|
|
dlog << LINFO << "start dotest3()";
|
|
print_spinner();
|
|
bool error_occurred = false;
|
|
{
|
|
dlib::pipe<unsigned short> ports(5);
|
|
thread_function t1(callfunct(test3_job, 12345, error_occurred, ports));
|
|
thread_function t2(callfunct(test3_job, 0, error_occurred, ports));
|
|
thread_function t3(callfunct(test3_job, 12347, error_occurred, ports));
|
|
|
|
|
|
try
|
|
{
|
|
std::vector<network_address> hosts;
|
|
unsigned short port;
|
|
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
|
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
|
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
|
int result = 0;
|
|
const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2;
|
|
bsp_connect(hosts, test3_job_driver, dlib::ref(result));
|
|
|
|
dlog << LINFO << "result: " << result;
|
|
dlog << LINFO << "should be: " << expected;
|
|
DLIB_TEST(result == expected);
|
|
}
|
|
catch (std::exception& e)
|
|
{
|
|
dlog << LERROR << "error during bsp_context: " << e.what();
|
|
DLIB_TEST(false);
|
|
}
|
|
|
|
}
|
|
DLIB_TEST(error_occurred == false);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
// Same exchange as test3_job_driver(), with two differences:
// - the traffic loop runs 40 rounds instead of 100;
// - after every hop it also calls the argument-less obj.receive().
// NOTE(review): that overload is presumably dlib's barrier-style receive (wait until
// no messages are pending and all nodes are blocked) — confirm against the
// bsp_context documentation.
void test4_job_driver(bsp_context& obj, int& result)
{
    const unsigned long successor = (obj.node_id()+1)%obj.number_of_nodes();

    obj.broadcast(obj.node_id());

    int accum = 0;
    for (int incoming = 0; obj.try_receive(incoming); )
        accum += incoming;

    // send to node 1 so it can sum everything
    if (obj.node_id() != 1)
        obj.send(accum, 1);

    for (int incoming = 0; obj.try_receive(incoming); )
        accum += incoming;

    // Three hops carry the value held by node 1 over to node 0.
    for (int hop = 0; hop < 3; ++hop)
    {
        obj.send(accum, successor);
        obj.receive(accum);
    }

    // Four hops per round with the four nodes of dotest4() leave accum unchanged.
    for (int k = 0; k < 40; ++k)
    {
        dlog << LINFO << "k: " << k;
        for (int hop = 0; hop < 4; ++hop)
        {
            obj.send(accum, successor);
            obj.receive(accum);

            obj.receive();
        }
    }

    dlog << LINFO << "TERMINATE";
    if (obj.node_id() == 0)
        result = accum;
}
|
|
|
|
|
|
// Listener-side wrapper: runs test4_job_driver() and discards its result, which
// only node 0 assigns.
void test4_job(bsp_context& obj)
{
    int ignored_result = 0;
    test4_job_driver(obj, ignored_result);
}
|
|
|
|
|
|
void dotest4()
|
|
{
|
|
dlog << LINFO << "start dotest4()";
|
|
print_spinner();
|
|
bool error_occurred = false;
|
|
{
|
|
dlib::pipe<unsigned short> ports(5);
|
|
thread_function t1(callfunct(test4_job, 0, error_occurred, ports));
|
|
thread_function t2(callfunct(test4_job, 0, error_occurred, ports));
|
|
thread_function t3(callfunct(test4_job, 0, error_occurred, ports));
|
|
|
|
|
|
try
|
|
{
|
|
std::vector<network_address> hosts;
|
|
unsigned short port;
|
|
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
|
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
|
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
|
int result = 0;
|
|
const int expected = 1+2+3 + 0+2+3 + 0+1+3 + 0+1+2;
|
|
bsp_connect(hosts, test4_job_driver, dlib::ref(result));
|
|
|
|
dlog << LINFO << "result: " << result;
|
|
dlog << LINFO << "should be: " << expected;
|
|
DLIB_TEST(result == expected);
|
|
}
|
|
catch (std::exception& e)
|
|
{
|
|
dlog << LERROR << "error during bsp_context: " << e.what();
|
|
DLIB_TEST(false);
|
|
}
|
|
|
|
}
|
|
DLIB_TEST(error_occurred == false);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
// Trivial job used by dotest5(): ignores the context and stores 25 in val.
void test5_job(bsp_context& /*unused*/, int& val)
{
    val = 25;
}
|
|
|
|
void dotest5()
|
|
{
|
|
dlog << LINFO << "start dotest5()";
|
|
print_spinner();
|
|
std::vector<network_address> hosts;
|
|
int val = 0;
|
|
bsp_connect(hosts, test5_job, dlib::ref(val));
|
|
DLIB_TEST(val == 25);
|
|
}
|
|
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
// Objective minimized by the grid search in dotest6(): (x-2)^2, which takes its
// minimum value 0 at x == 2.
double f ( double x)
{
    // Square by a single multiplication: correctly rounded and far cheaper than
    // the general-purpose std::pow for an integer exponent.
    const double d = x - 2.0;
    return d*d;
}
|
|
|
|
|
|
// Node-0 side of a distributed grid search for the minimum of f().
// Each of 100 rounds:
// - broadcasts the current interval [left, right];
// - collects one (x, f(x)) pair per other node and keeps the smallest f(x) in
//   min_value / optimal_x;
// - recentres the interval on optimal_x with half the previous width.
void bsp_job_node_0 (
    bsp_context& context,
    double& min_value,
    double& optimal_x
)
{
    double left = -100;
    double right = 100;
    double interval_width = std::abs(right-left);
    min_value = std::numeric_limits<double>::infinity();

    const int num_rounds = 100;
    for (int round = 0; round < num_rounds; ++round)
    {
        context.broadcast(left);
        context.broadcast(right);

        // Every node other than this one replies exactly once.
        for (unsigned int replies = 1; replies < context.number_of_nodes(); ++replies)
        {
            std::pair<double,double> candidate;
            context.receive(candidate);
            if (candidate.second < min_value)
            {
                optimal_x = candidate.first;
                min_value = candidate.second;
            }
        }

        interval_width *= 0.5;
        const double half_width = interval_width/2;
        left = optimal_x - half_width;
        right = optimal_x + half_width;
    }
}
|
|
|
|
|
|
// Worker side of the grid search. For each (left, right) pair received:
// - node k (k >= 1) takes the k-th of (number_of_nodes()-1) equal slices of the
//   interval;
// - it evaluates f() at 100 evenly spaced points in that slice;
// - it sends the best (x, f(x)) pair to node 0.
// The loop ends when try_receive() returns false.
void bsp_job_other_nodes (
    bsp_context& context
)
{
    double left, right;
    while (context.try_receive(left))
    {
        context.receive(right);

        const double slice_begin = (context.node_id()-1)/(context.number_of_nodes()-1.0);
        const double slice_end   = context.node_id()    /(context.number_of_nodes()-1.0);
        const double span = right-left;
        const matrix<double> grid = linspace(left + slice_begin*span, left + slice_end*span, 100);

        double best_x = 0;
        double best_val = std::numeric_limits<double>::infinity();
        for (long idx = 0; idx < grid.size(); ++idx)
        {
            const double x = grid(idx);
            const double fx = f(x);
            if (fx < best_val)
            {
                best_val = fx;
                best_x = x;
            }
        }

        context.send(make_pair(best_x, best_val), 0);
    }
}
|
|
|
|
void dotest6()
|
|
{
|
|
dlog << LINFO << "start dotest6()";
|
|
print_spinner();
|
|
bool error_occurred = false;
|
|
{
|
|
dlib::pipe<unsigned short> ports(5);
|
|
thread_function t1(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
|
|
thread_function t2(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
|
|
thread_function t3(callfunct(bsp_job_other_nodes, 0, error_occurred, ports));
|
|
|
|
|
|
try
|
|
{
|
|
std::vector<network_address> hosts;
|
|
unsigned short port;
|
|
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
|
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
|
ports.dequeue(port); hosts.push_back(network_address("127.0.0.1",port)); dlog << LINFO << "PORT: " << port;
|
|
double min_value = 10, optimal_x = 0;
|
|
bsp_connect(hosts, bsp_job_node_0, dlib::ref(min_value), dlib::ref(optimal_x));
|
|
|
|
dlog << LINFO << "min_value: " << min_value;
|
|
dlog << LINFO << "optimal_x: " << optimal_x;
|
|
DLIB_TEST(std::abs(min_value - 0) < 1e-14);
|
|
DLIB_TEST(std::abs(optimal_x - 2) < 1e-14);
|
|
}
|
|
catch (std::exception& e)
|
|
{
|
|
dlog << LERROR << "error during bsp_context: " << e.what();
|
|
DLIB_TEST(false);
|
|
}
|
|
|
|
}
|
|
DLIB_TEST(error_occurred == false);
|
|
}
|
|
// ----------------------------------------------------------------------------------------
|
|
|
|
class bsp_tester : public tester
|
|
{
|
|
|
|
public:
|
|
bsp_tester (
|
|
) :
|
|
tester ("test_bsp",
|
|
"Runs tests on the BSP components.")
|
|
{}
|
|
|
|
void perform_test (
|
|
)
|
|
{
|
|
for (int i = 0; i < 3; ++i)
|
|
{
|
|
dotest1();
|
|
dotest2<0>();
|
|
dotest2<1>();
|
|
dotest2<2>();
|
|
dotest3();
|
|
dotest4();
|
|
dotest5();
|
|
dotest6();
|
|
}
|
|
}
|
|
} a;
|
|
|
|
}
|
|
|