#include <unistd.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#define MAX_INPUTS 128
#define MAX_REDUCERS 16

/*
 * Writes exactly `len` bytes from `buf` to `fd`, looping over short writes.
 *
 * A write interrupted by a signal (EINTR) is retried. Any other write error
 * terminates the process immediately via _exit(1).
 */
static void write_all(int fd, const char *buf, size_t len)
{
    size_t done = 0;

    while (done < len) {
        ssize_t written = write(fd, buf + done, len - done);

        if (written >= 0) {
            done += (size_t)written;
            continue;
        }

        /* Only an interrupted call is worth retrying. */
        if (errno != EINTR) {
            _exit(1);
        }
    }
}

/* Writes the NUL-terminated string `s` (without its terminator) to `fd`. */
static void write_str(int fd, const char *s)
{
    const size_t length = strlen(s);

    write_all(fd, s, length);
}

/*
 * Writes "<prefix>: <strerror(errno)>\n" to standard error.
 *
 * The message is formatted into a fixed 512-byte buffer. A message that does
 * not fit is cut to the buffer size and still ends with a newline. errno is
 * restored before returning so the caller can keep inspecting it.
 */
static void log_errno_msg(const char *prefix)
{
    const int saved_errno = errno;
    char buf[512];
    int n = snprintf(buf, sizeof(buf), "%s: %s\n", prefix, strerror(saved_errno));

    if (n > 0) {
        size_t len = (size_t)n;

        /*
         * snprintf returns the length the full message would have had, which
         * can exceed the buffer; never hand write_all more than was stored.
         */
        if (len >= sizeof(buf)) {
            len = sizeof(buf) - 1;
            buf[len - 1] = '\n';
        }
        write_all(STDERR_FILENO, buf, len);
    }

    errno = saved_errno;
}

/*
 * Writes the command-line usage summary to standard error.
 *
 * `progname` is inserted as the program name. The text is formatted into a
 * fixed 1024-byte buffer; if `progname` is long enough to overflow it, the
 * message is cut to the buffer size and still ends with a newline.
 */
static void log_usage(const char *progname)
{
    char buf[1024];
    int n = snprintf(
        buf, sizeof(buf),
        "Uso: %s [-m] [-r nreducers] [-o output_file] <input_files...>\n",
        progname
    );

    if (n > 0) {
        size_t len = (size_t)n;

        /*
         * snprintf returns the untruncated length; clamp it so write_all
         * never reads beyond the end of buf.
         */
        if (len >= sizeof(buf)) {
            len = sizeof(buf) - 1;
            buf[len - 1] = '\n';
        }
        write_all(STDERR_FILENO, buf, len);
    }
}

/*
 * Makes sure `path` names a directory, creating it with mode 0755 if needed.
 *
 * Returns 0 when the directory exists on return, -1 otherwise with errno set
 * (ENOTDIR when `path` exists but is not a directory, or the errno left by
 * stat/mkdir).
 */
static int ensure_dir_exists(const char *path)
{
    struct stat st;

    if (stat(path, &st) == 0) {
        if (S_ISDIR(st.st_mode)) {
            return 0;
        }
        errno = ENOTDIR;
        return -1;
    }

    if (mkdir(path, 0755) == 0) {
        return 0;
    }

    if (errno != EEXIST) {
        return -1;
    }

    /*
     * The path appeared between stat and mkdir (for example another instance
     * created it concurrently). That is only a failure if it is not a
     * directory.
     */
    if (stat(path, &st) != 0) {
        return -1;
    }
    if (S_ISDIR(st.st_mode)) {
        return 0;
    }
    errno = ENOTDIR;
    return -1;
}

/*
 * Announces on standard output that a child process was started.
 *
 * Emits "[WordCount:<stage>] Created <cmd> process <pid>.\n". The line is
 * formatted into a fixed 256-byte buffer; an over-long line is cut to the
 * buffer size and still ends with a newline.
 */
static void log_created(const char *stage, const char *cmd, pid_t pid)
{
    char buf[256];
    int n = snprintf(buf, sizeof(buf),
                     "[WordCount:%s] Created %s process %d.\n",
                     stage, cmd, (int)pid);

    if (n > 0) {
        size_t len = (size_t)n;

        /*
         * snprintf returns the untruncated length; clamp it so write_all
         * never reads beyond the end of buf.
         */
        if (len >= sizeof(buf)) {
            len = sizeof(buf) - 1;
            buf[len - 1] = '\n';
        }
        write_all(STDOUT_FILENO, buf, len);
    }
}

/*
 * Blocks until child `pid` terminates, retrying waitpid when it is
 * interrupted by a signal.
 *
 * Returns 0 and stores the child's exit code in `*exit_status` when the
 * child exited normally. Returns -1 when waitpid fails (`*exit_status` is
 * left untouched) or when the child did not exit normally (`*exit_status`
 * is set to -1).
 */
static int wait_for_pid(pid_t pid, int *exit_status)
{
    int raw_status;

    for (;;) {
        if (waitpid(pid, &raw_status, 0) >= 0) {
            break;
        }
        if (errno != EINTR) {
            return -1;
        }
    }

    if (!WIFEXITED(raw_status)) {
        *exit_status = -1;
        return -1;
    }

    *exit_status = WEXITSTATUS(raw_status);
    return 0;
}

/*
 * Forks and runs `program` with argument vector `argv` in the child.
 *
 * In the parent, logs the creation under `stage` and returns the child's
 * pid, or returns -1 if fork failed. In the child, a failed execv is
 * reported on standard error and the child exits with status 127.
 */
static pid_t spawn_process(const char *stage, const char *program, char *const argv[])
{
    pid_t child = fork();

    if (child == 0) {
        /* Child: execv only returns on failure. */
        execv(program, argv);
        log_errno_msg(program);
        _exit(127);
    }

    if (child > 0) {
        log_created(stage, program, child);
        return child;
    }

    return -1;
}

/*
 * Formats "<dir>/<prefix><index>.tmp" into `out`.
 *
 * At most `out_size` bytes are written including the terminator; a longer
 * result is truncated without any indication to the caller.
 */
static void build_tmp_path(char *out, size_t out_size,
                           const char *dir, const char *prefix, int index)
{
    (void)snprintf(out, out_size, "%s/%s%d.tmp", dir, prefix, index);
}

/*
 * Removes intermediate files left in `tmp_dir`.
 *
 * Deletes "_map_<i>.tmp" for i in [0, max_maps) and, for each i in
 * [0, max_reducers), "_batch_<i>.tmp" followed by "_suffle_<i>.tmp".
 * unlink failures (such as a file that does not exist) are ignored.
 */
static void cleanup_tmp_files(const char *tmp_dir, int max_maps, int max_reducers)
{
    static const char *const reducer_prefixes[] = { "_batch_", "_suffle_" };
    const size_t n_prefixes = sizeof(reducer_prefixes) / sizeof(reducer_prefixes[0]);
    char path[512];
    int map_idx = 0;
    int red_idx = 0;

    while (map_idx < max_maps) {
        build_tmp_path(path, sizeof(path), tmp_dir, "_map_", map_idx);
        unlink(path);
        ++map_idx;
    }

    while (red_idx < max_reducers) {
        size_t k;

        for (k = 0; k < n_prefixes; ++k) {
            build_tmp_path(path, sizeof(path), tmp_dir, reducer_prefixes[k], red_idx);
            unlink(path);
        }
        ++red_idx;
    }
}

/*
 * Removes result files left by a previous run.
 *
 * Always unlinks `base_output`. When `n_reducers` is greater than 1 it also
 * unlinks the per-reducer files numbered 1..n_reducers: "_<k>" is inserted
 * before the extension when the last path component has one
 * ("dir/wc.txt" -> "dir/wc_1.txt"), and appended otherwise
 * ("dir/wc" -> "dir/wc_1").
 *
 * unlink failures are ignored. A name that does not fit the local buffer is
 * skipped rather than unlinked in truncated form, because the truncated name
 * would refer to a different file.
 */
static void cleanup_result_files(const char *base_output, int n_reducers)
{
    char out[512];
    const char *dot = strrchr(base_output, '.');
    const char *slash = strrchr(base_output, '/');
    /* A dot only marks an extension when it lies in the last path component. */
    const int has_ext = (dot != NULL && (slash == NULL || dot > slash));
    const int stem_len = has_ext ? (int)(dot - base_output) : 0;
    int i;

    (void)unlink(base_output);

    if (n_reducers <= 1) {
        return;
    }

    for (i = 0; i < n_reducers; ++i) {
        int n;

        if (has_ext) {
            n = snprintf(out, sizeof(out), "%.*s_%d%s",
                         stem_len, base_output, i + 1, dot);
        } else {
            n = snprintf(out, sizeof(out), "%s_%d", base_output, i + 1);
        }

        /* snprintf reports the full length; >= sizeof(out) means truncation. */
        if (n < 0 || (size_t)n >= sizeof(out)) {
            continue;
        }

        (void)unlink(out);
    }
}

/*
 * Computes the output path for one reducer and stores it in `out`.
 *
 * With at most one reducer the path is `base_output` unchanged. Otherwise
 * the 1-based reducer number is added as "_<k>": before the extension when
 * the last path component has one ("dir/wc.txt" -> "dir/wc_2.txt"), at the
 * end otherwise ("dir.d/wc" -> "dir.d/wc_2"). The result is truncated to
 * `out_size` bytes including the terminator.
 */
static void build_reduce_output(char *out, size_t out_size,
                                const char *base_output, int reducer_index, int total_reducers)
{
    const char *ext;
    const char *last_sep;
    int ordinal;

    if (total_reducers <= 1) {
        snprintf(out, out_size, "%s", base_output);
        return;
    }

    ordinal = reducer_index + 1;
    ext = strrchr(base_output, '.');
    last_sep = strrchr(base_output, '/');

    /* A dot inside a directory name is not an extension. */
    if (ext != NULL && last_sep != NULL && ext < last_sep) {
        ext = NULL;
    }

    if (ext == NULL) {
        snprintf(out, out_size, "%s_%d", base_output, ordinal);
        return;
    }

    snprintf(out, out_size, "%.*s_%d%s",
             (int)(ext - base_output), base_output, ordinal, ext);
}

/* Waits for the first `count` children in `pids`, discarding their results. */
static void reap_children(const pid_t *pids, int count)
{
    int k;

    for (k = 0; k < count; ++k) {
        int ignored = 0;

        (void)wait_for_pid(pids[k], &ignored);
    }
}

/*
 * Reports a failed wait on a mapper. Must be called right after
 * wait_for_pid returned -1, with `status` initialised to 0 before that call:
 * wait_for_pid stores -1 only when the child did not exit normally, and in
 * that case errno carries no information about the failure.
 */
static void log_mapper_wait_failure(int status)
{
    if (status == -1) {
        write_str(STDERR_FILENO, "[WordCount:map] Mapper terminated abnormally.\n");
    } else {
        log_errno_msg("wait mapper");
    }
}

/*
 * Driver for the word-count pipeline: runs ./WordCountMap, then
 * ./WordCountSuffle, then one ./WordCountReduce per reducer, waiting for
 * each stage to finish before starting the next.
 *
 * Command line: [-m] [-r nreducers] [-o output_file] <input_files...>
 *   -m            start one mapper per input file instead of a single mapper
 *                 that receives every input file
 *   -r nreducers  number of reducers, 1..MAX_REDUCERS (default 1)
 *   -o file       output path (default "./Result/wc.txt")
 *
 * Returns 0 on success and 1 on a usage error, a failure to create the
 * working directories, a failure to spawn or wait for a child, or a
 * non-zero exit from the shuffle or a reduce process. Mapper exit codes are
 * printed but do not abort the run.
 */
int main(int argc, char *argv[])
{
    int multi_map = 0;
    int n_reducers = 1;
    const char *output_file = "./Result/wc.txt";
    const char *tmp_dir = "./tmp";
    /* argv[0] may legitimately be missing when argc is 0. */
    const char *progname = (argc > 0 && argv[0] != NULL) ? argv[0] : "WordCountMR";

    char *inputs[MAX_INPUTS];
    int n_inputs = 0;
    int i;

    pid_t map_pids[MAX_INPUTS];
    int map_status[MAX_INPUTS];

    pid_t reduce_pids[MAX_REDUCERS];
    int reduce_status[MAX_REDUCERS];

    int status_suffle = 0;

    char *map_argv[MAX_INPUTS + 8];
    char *suffle_argv[8];
    char *reduce_argv[8];

    char map_index_buf[MAX_INPUTS][32];
    char nmap_buf[32];
    char nr_buf[32];
    char part_buf[MAX_REDUCERS][32];
    char reduce_out_buf[MAX_REDUCERS][512];

    char suffix_buf[32];
    int suffix_len;

    if (argc < 2) {
        log_usage(progname);
        return 1;
    }

    for (i = 1; i < argc; ++i) {
        if (strcmp(argv[i], "-m") == 0) {
            multi_map = 1;
        } else if (strcmp(argv[i], "-r") == 0) {
            char *end = NULL;
            long value;

            if (i + 1 >= argc) {
                log_usage(progname);
                return 1;
            }

            /* strtol, unlike atoi, lets us reject "3x", "" and overflow. */
            ++i;
            errno = 0;
            value = strtol(argv[i], &end, 10);
            if (errno != 0 || end == argv[i] || *end != '\0' ||
                value < 1 || value > MAX_REDUCERS) {
                write_str(STDERR_FILENO, "Error: nreducers debe estar entre 1 y 16.\n");
                return 1;
            }
            n_reducers = (int)value;
        } else if (strcmp(argv[i], "-o") == 0) {
            if (i + 1 >= argc) {
                log_usage(progname);
                return 1;
            }
            output_file = argv[++i];
        } else {
            if (n_inputs >= MAX_INPUTS) {
                write_str(STDERR_FILENO, "Error: demasiados ficheros de entrada.\n");
                return 1;
            }
            inputs[n_inputs++] = argv[i];
        }
    }

    if (n_inputs == 0) {
        log_usage(progname);
        return 1;
    }

    /*
     * Every per-reducer path is the output path plus "_<k>" and must fit in
     * reduce_out_buf; a truncated path would make reducers write to the
     * wrong (possibly the same) file.
     */
    suffix_len = snprintf(suffix_buf, sizeof(suffix_buf), "_%d", MAX_REDUCERS);
    if (suffix_len < 0 ||
        strlen(output_file) + (size_t)suffix_len >= sizeof(reduce_out_buf[0])) {
        write_str(STDERR_FILENO, "Error: la ruta de salida es demasiado larga.\n");
        return 1;
    }

    if (ensure_dir_exists(tmp_dir) < 0) {
        log_errno_msg("tmp_dir");
        return 1;
    }

    if (ensure_dir_exists("./Result") < 0) {
        log_errno_msg("Result");
        return 1;
    }

    cleanup_tmp_files(tmp_dir, n_inputs, n_reducers);
    cleanup_result_files(output_file, n_reducers);

    /* ---------- MAP ---------- */
    if (multi_map) {
        /* One mapper per input: argv is <"1"> <tmp_dir> <index> <input>. */
        for (i = 0; i < n_inputs; ++i) {
            int pos = 0;

            snprintf(map_index_buf[i], sizeof(map_index_buf[i]), "%d", i);

            map_argv[pos++] = "./WordCountMap";
            map_argv[pos++] = "1";
            map_argv[pos++] = (char *)tmp_dir;
            map_argv[pos++] = map_index_buf[i];
            map_argv[pos++] = inputs[i];
            map_argv[pos] = NULL;

            map_pids[i] = spawn_process("map", "./WordCountMap", map_argv);
            if (map_pids[i] < 0) {
                log_errno_msg("spawn WordCountMap");
                /* Do not leave the mappers already started unreaped. */
                reap_children(map_pids, i);
                return 1;
            }
        }

        for (i = 0; i < n_inputs; ++i) {
            map_status[i] = 0;
            if (wait_for_pid(map_pids[i], &map_status[i]) < 0) {
                log_mapper_wait_failure(map_status[i]);
                reap_children(map_pids + i + 1, n_inputs - i - 1);
                return 1;
            }

            {
                char buf[256];
                int n = snprintf(buf, sizeof(buf),
                                 "[WordCount:map] Mapper %d returned %d.\n",
                                 i, map_status[i]);
                if (n > 0) {
                    write_all(STDOUT_FILENO, buf, (size_t)n);
                }
            }
        }
    } else {
        /* Single mapper: argv is <"0"> <tmp_dir> <"0"> <input>... */
        int pos = 0;
        pid_t pid;
        int status_map = 0;

        map_argv[pos++] = "./WordCountMap";
        map_argv[pos++] = "0";
        map_argv[pos++] = (char *)tmp_dir;
        map_argv[pos++] = "0";

        for (i = 0; i < n_inputs; ++i) {
            map_argv[pos++] = inputs[i];
        }
        map_argv[pos] = NULL;

        pid = spawn_process("map", "./WordCountMap", map_argv);
        if (pid < 0) {
            log_errno_msg("spawn WordCountMap");
            return 1;
        }

        if (wait_for_pid(pid, &status_map) < 0) {
            log_mapper_wait_failure(status_map);
            return 1;
        }

        {
            char buf[256];
            int n = snprintf(buf, sizeof(buf),
                             "[WordCount:map] Mapper returned %d.\n",
                             status_map);
            if (n > 0) {
                write_all(STDOUT_FILENO, buf, (size_t)n);
            }
        }
    }

    /* ---------- SUFFLE ---------- */
    snprintf(nmap_buf, sizeof(nmap_buf), "%d", multi_map ? n_inputs : 1);
    snprintf(nr_buf, sizeof(nr_buf), "%d", n_reducers);

    suffle_argv[0] = "./WordCountSuffle";
    suffle_argv[1] = (char *)tmp_dir;
    suffle_argv[2] = nmap_buf;
    suffle_argv[3] = nr_buf;
    suffle_argv[4] = NULL;

    {
        pid_t pid = spawn_process("suffle", "./WordCountSuffle", suffle_argv);
        if (pid < 0) {
            log_errno_msg("spawn WordCountSuffle");
            return 1;
        }

        if (wait_for_pid(pid, &status_suffle) < 0 || status_suffle != 0) {
            write_str(STDERR_FILENO, "[WordCount:suffle] Error running WordCountSuffle.\n");
            return 1;
        }
    }

    /* ---------- REDUCE ---------- */
    for (i = 0; i < n_reducers; ++i) {
        snprintf(part_buf[i], sizeof(part_buf[i]), "%d", i);
        build_reduce_output(reduce_out_buf[i], sizeof(reduce_out_buf[i]),
                            output_file, i, n_reducers);

        reduce_argv[0] = "./WordCountReduce";
        reduce_argv[1] = (char *)tmp_dir;
        reduce_argv[2] = part_buf[i];
        reduce_argv[3] = reduce_out_buf[i];
        reduce_argv[4] = NULL;

        reduce_pids[i] = spawn_process("reduce", "./WordCountReduce", reduce_argv);
        if (reduce_pids[i] < 0) {
            log_errno_msg("spawn WordCountReduce");
            /* Do not leave the reducers already started unreaped. */
            reap_children(reduce_pids, i);
            return 1;
        }
    }

    for (i = 0; i < n_reducers; ++i) {
        reduce_status[i] = 0;
        if (wait_for_pid(reduce_pids[i], &reduce_status[i]) < 0 || reduce_status[i] != 0) {
            write_str(STDERR_FILENO, "[WordCount:reduce] Error running WordCountReduce.\n");
            reap_children(reduce_pids + i + 1, n_reducers - i - 1);
            return 1;
        }
    }

    {
        char buf[512];
        int n = snprintf(buf, sizeof(buf),
                         "[WordCount:wordcountmr] WordCountMR completed processing %d file(s) with %d reducer(s).\n",
                         n_inputs, n_reducers);
        if (n > 0) {
            write_all(STDOUT_FILENO, buf, (size_t)n);
        }
    }

    return 0;
}
