/* hackbench (thread variant): groups of 20 sender threads, each spraying messages to 20 receiver threads */
#include <errno.h>
#include <limits.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <sys/types.h>
#include <pthread.h>
#include <unistd.h>
#include <sys/socket.h>
#include <sys/wait.h>
#include <sys/time.h>
#include <sys/poll.h>

/* Size in bytes of every message a sender writes and a receiver reads. */
enum { DATASIZE = 100 };

/* Number of messages each sender writes to each of its output fds. */
static unsigned int loops = 100;

/* Non-zero selects pipe(2) instead of an AF_UNIX socketpair(2); set by "-pipe". */
static int use_pipes;

/* NOTE(review): no live code in the visible file reads this flag; confirm whether it is still needed. */
static int realtime;

/*
 * State shared by all sender threads of one group.
 * out_fds[] holds num_fds write ends, one per receiver; allocate as
 * malloc(sizeof(struct sender_context) + num_fds * sizeof(int)).
 */
struct sender_context {
	unsigned int num_fds;	/* number of entries in out_fds[] */
	int ready_out;		/* fd the one-byte "ready" token is written to */
	int wakefd;		/* fd polled for the "go" signal */
	int out_fds[];		/* C99 flexible array member */
};

struct receiver_context {
	unsigned int num_packets;
	int in_fd;
	int ready_out;
	int wakefd;
};


/*
 * Fatal-error helper: print "<msg> (error: <strerror(errno)>)" to stderr
 * and terminate the process with exit status 1. Never returns.
 */
static void barf(const char *msg)
{
	const char *why = strerror(errno);

	fprintf(stderr, "%s (error: %s)\n", msg, why);
	exit(1);
}

/*
 * Create a connected fd pair in fds[0]/fds[1]: a pipe when use_pipes is set,
 * otherwise an AF_UNIX stream socketpair. Exits via barf() on failure.
 */
static void fdpair(int fds[2])
{
	int rc = use_pipes ? pipe(fds)
			   : socketpair(AF_UNIX, SOCK_STREAM, 0, fds);

	if (rc != 0)
		barf("Creating fdpair");
}

/*
 * Start barrier used by every worker thread: announce readiness by writing
 * one byte to ready_out, then block until wakefd becomes readable.
 * Calls that fail with EINTR are retried; any other failure exits via barf().
 */
static void ready(int ready_out, int wakefd)
{
	char token = 0;	/* value is irrelevant; initialised so no indeterminate byte is sent */
	struct pollfd pfd = { .fd = wakefd, .events = POLLIN };
	ssize_t written;
	int nready;

	/* Tell them we're ready. */
	do {
		written = write(ready_out, &token, 1);
	} while (written < 0 && errno == EINTR);
	if (written != 1)
		barf("CLIENT: ready write");

	/* Wait for "GO" signal */
	do {
		nready = poll(&pfd, 1, -1);
	} while (nready < 0 && errno == EINTR);
	if (nready != 1)
		barf("poll");
}

/*
 * Sender thread start routine. arg is a struct sender_context *.
 *
 * After the start barrier, writes `loops` rounds of one DATASIZE-byte
 * message to every fd in ctx->out_fds, completing short writes and
 * retrying on EINTR. Exits the process via barf() on any other write error.
 *
 * Takes void * so it matches the pthread start-routine type exactly;
 * calling through a mismatched function-pointer type is undefined behavior.
 */
static void *sender(void *arg)
{
	struct sender_context *ctx = arg;
	char data[DATASIZE];
	unsigned int i, j;

	/* Payload content is irrelevant; zero it so no indeterminate bytes are sent. */
	memset(data, 0, sizeof(data));

	ready(ctx->ready_out, ctx->wakefd);

	/* Now pump to every receiver. */
	for (i = 0; i < loops; i++) {
		for (j = 0; j < ctx->num_fds; j++) {
			size_t done = 0;

			while (done < sizeof(data)) {
				ssize_t ret = write(ctx->out_fds[j], data + done,
						    sizeof(data) - done);

				if (ret < 0) {
					if (errno == EINTR)
						continue;
					barf("SENDER: write");
				} else {
					done += (size_t)ret;
				}
			}
		}
	}
	return NULL;
}

/*
 * Receiver thread start routine, one per fd. arg is a struct receiver_context *.
 *
 * After the start barrier, reads ctx->num_packets messages of DATASIZE bytes
 * from ctx->in_fd, completing short reads and retrying on EINTR.
 * A read error exits via barf(). A premature EOF also exits; without that
 * check, read() returning 0 would make the loop spin forever.
 *
 * Takes void * so it matches the pthread start-routine type exactly.
 */
static void *receiver(void *arg)
{
	struct receiver_context *ctx = arg;
	unsigned int i;

	/* Wait for start... */
	ready(ctx->ready_out, ctx->wakefd);

	/* Receive them all */
	for (i = 0; i < ctx->num_packets; i++) {
		char data[DATASIZE];
		size_t done = 0;

		while (done < sizeof(data)) {
			ssize_t ret = read(ctx->in_fd, data + done,
					   sizeof(data) - done);

			if (ret < 0) {
				if (errno == EINTR)
					continue;
				barf("SERVER: read");
			} else if (ret == 0) {
				fprintf(stderr, "SERVER: unexpected EOF\n");
				exit(1);
			} else {
				done += (size_t)ret;
			}
		}
	}
	return NULL;
}

/*
 * Report a pthread-style failure and exit with status 1.
 * pthread_* functions return the error code instead of setting errno,
 * so the code must be passed in explicitly.
 */
static void barf_errcode(const char *msg, int err)
{
	fprintf(stderr, "%s (error: %s)\n", msg, strerror(err));
	exit(1);
}

/*
 * Start a thread running func(ctx) with a small stack and default scheduling.
 * The stack is 16 KiB, raised to PTHREAD_STACK_MIN when the platform requires
 * more, because pthread_attr_setstacksize() rejects smaller sizes.
 * Returns the new thread's id; any failure terminates the process.
 */
pthread_t create_thread(void *ctx, void *(*func)(void *))
{
	pthread_attr_t attr;
	pthread_t childid;
	size_t stacksize = 16 * 1024;
	int err;

	err = pthread_attr_init(&attr);
	if (err != 0)
		barf_errcode("pthread_attr_init:", err);

#ifdef PTHREAD_STACK_MIN
	if (stacksize < (size_t)PTHREAD_STACK_MIN)
		stacksize = (size_t)PTHREAD_STACK_MIN;
#endif
	err = pthread_attr_setstacksize(&attr, stacksize);
	if (err != 0)
		barf_errcode("pthread_attr_setstacksize", err);

	err = pthread_create(&childid, &attr, func, ctx);
	if (err != 0) {
		fprintf(stderr, "pthread_create failed: %s (%d)\n", strerror(err), err);
		exit(-1);
	}

	/* The attribute object is no longer needed once the thread exists. */
	pthread_attr_destroy(&attr);
	return childid;
}

/*
 * Build one group: num_fds receiver threads, each owning the read end of a
 * fresh fd pair, plus num_fds sender threads sharing one context that holds
 * all the write ends. Thread ids are stored in pth[0 .. 2*num_fds-1],
 * receivers first. All threads use ready_out/wakefd for the start barrier.
 *
 * The contexts are not freed here; they must outlive the threads, and this
 * function keeps no record of them.
 *
 * Returns the number of threads created (to be joined by the caller).
 */
static unsigned int group(pthread_t *pth,
			  unsigned int num_fds,
			  int ready_out,
			  int wakefd)
{
	unsigned int i;
	struct sender_context *snd_ctx =
		malloc(sizeof(*snd_ctx) + num_fds * sizeof(snd_ctx->out_fds[0]));

	if (!snd_ctx)
		barf("malloc()");

	/* Identical for every sender of the group: fill in once. */
	snd_ctx->num_fds = num_fds;
	snd_ctx->ready_out = ready_out;
	snd_ctx->wakefd = wakefd;

	for (i = 0; i < num_fds; i++) {
		int fds[2];
		struct receiver_context *ctx = malloc(sizeof(*ctx));

		if (!ctx)
			barf("malloc()");

		/* Create the pipe between client and server */
		fdpair(fds);

		/* Each receiver gets `loops` messages from each of num_fds senders. */
		ctx->num_packets = num_fds * loops;
		ctx->in_fd = fds[0];
		ctx->ready_out = ready_out;
		ctx->wakefd = wakefd;

		/* Function-pointer cast (not via void *, which is non-standard). */
		pth[i] = create_thread(ctx, (void *(*)(void *))receiver);

		snd_ctx->out_fds[i] = fds[1];
	}

	/* Now we have all the fds, start the senders */
	for (i = 0; i < num_fds; i++)
		pth[num_fds + i] = create_thread(snd_ctx, (void *(*)(void *))sender);

	/* Return number of children to reap */
	return num_fds * 2;
}

/*
 * Usage: hackbench [-pipe] <num groups>
 *
 * Creates <num groups> groups of 20 receivers + 20 senders. It waits until
 * every thread has written its ready byte, then releases them all with one
 * byte on the wake pipe. It joins every thread and prints the elapsed
 * wall-clock time in seconds with millisecond precision.
 */
int main(int argc, char *argv[])
{
	unsigned int i, num_groups = 0, total_children;
	unsigned int num_fds = 20;
	struct timeval start, stop, diff;
	int readyfds[2], wakefds[2];
	char dummy = 0;	/* byte value is irrelevant; initialised so it is never indeterminate */
	pthread_t *pth_tab;
	int valid = 0;

	/* argc > 1 guarantees argv[1] exists before it is examined. */
	if (argc > 1 && strcmp(argv[1], "-pipe") == 0) {
		use_pipes = 1;
		argc--;
		argv++;
	}

	/*
	 * strtoul instead of atoi: reject trailing garbage, zero, negatives
	 * (which wrap to huge values), and counts whose thread total
	 * (2 * num_fds * num_groups) would overflow unsigned int.
	 */
	if (argc == 2) {
		char *end;
		unsigned long parsed;

		errno = 0;
		parsed = strtoul(argv[1], &end, 10);
		valid = end != argv[1] && *end == '\0' && errno != ERANGE &&
			parsed != 0 && parsed <= UINT_MAX / (2 * num_fds);
		if (valid)
			num_groups = (unsigned int)parsed;
	}
	if (!valid) {
		fprintf(stderr, "Usage: hackbench [-pipe] <num groups>\n");
		exit(1);
	}

	/* calloc checks the element-count * element-size product for overflow. */
	pth_tab = calloc((size_t)num_groups * 2 * num_fds, sizeof(*pth_tab));
	if (!pth_tab)
		barf("main:malloc()");

	fdpair(readyfds);
	fdpair(wakefds);

	total_children = 0;
	for (i = 0; i < num_groups; i++)
		total_children += group(pth_tab + total_children, num_fds,
					readyfds[1], wakefds[0]);

	/* Wait for everyone to be ready */
	for (i = 0; i < total_children; i++)
		if (read(readyfds[0], &dummy, 1) != 1)
			barf("Reading for readyfds");

	gettimeofday(&start, NULL);

	/*
	 * Kick them off: the workers only poll() wakefds[0] and never read it,
	 * so one byte makes it readable for all of them.
	 */
	if (write(wakefds[1], &dummy, 1) != 1)
		barf("Writing to start them");

	/* Reap them all. Exit values are unused, so no result pointer is passed. */
	for (i = 0; i < total_children; i++) {
		int err = pthread_join(pth_tab[i], NULL);

		if (err != 0) {
			fprintf(stderr, "pthread_join failed: %s (%d)\n",
				strerror(err), err);
			exit(1);
		}
	}

	gettimeofday(&stop, NULL);

	/* Print time... (casts match the %lu conversions exactly) */
	timersub(&stop, &start, &diff);
	printf("Time: %lu.%03lu\n", (unsigned long)diff.tv_sec,
	       (unsigned long)(diff.tv_usec / 1000));

	free(pth_tab);
	exit(0);
}
