#include <unistd.h>
#include <sys/types.h>
#include <sys/wait.h>
#include <sys/stat.h>
#include <fcntl.h>
#include <errno.h>
#include <stdlib.h>
#include <string.h>
#include <stdio.h>

#define TMP1   "tmp1.txt"
#define TMP2   "tmp2.txt"
#define TMP3   "tmp3.txt"
#define TMP4   "tmp4.txt"
#define RESULT "result.txt"

/*
 * Write exactly `len` bytes from `buf` to `fd`.
 *
 * Short writes are resumed and EINTR is retried.  Any other failure
 * terminates the process with _exit(1); callers have no error path.
 * This includes a write() that reports success but transfers 0 bytes,
 * which would otherwise loop forever.
 */
static void write_all(int fd, const char *buf, size_t len)
{
    while (len > 0) {
        ssize_t n = write(fd, buf, len);
        if (n < 0) {
            if (errno == EINTR) {
                continue;
            }
            _exit(1);
        }
        if (n == 0) {
            /* No progress and no error reported: bail out instead of spinning. */
            _exit(1);
        }
        buf += n;
        len -= (size_t)n;
    }
}

/* Emit the NUL-terminated string `s` to `fd`, excluding the terminator. */
static void write_str(int fd, const char *s)
{
    const size_t length = strlen(s);

    write_all(fd, s, length);
}

/*
 * Print "<prefix>: <strerror(errno)>\n" to stderr.
 *
 * Call this immediately after the failing call, before anything else can
 * overwrite errno.  Messages that do not fit in the local buffer are
 * truncated but still end with a newline.
 */
static void log_errno_msg(const char *prefix)
{
    char buf[512];
    const char *reason = strerror(errno);
    int n = snprintf(buf, sizeof(buf), "%s: %s\n", prefix, reason);

    if (n <= 0) {
        return;
    }
    /* snprintf returns the untruncated length; never write past what it stored. */
    if ((size_t)n >= sizeof(buf)) {
        n = (int)(sizeof(buf) - 1);
        buf[n - 1] = '\n';
    }
    write_all(STDERR_FILENO, buf, (size_t)n);
}

/*
 * Print the usage line to stderr.
 *
 * `progname` is normally argv[0].  It may be NULL when the program is exec'd
 * with an empty argv, in which case a fixed name is substituted.  argv[0] has
 * no length bound, so the message is truncated (newline kept) if it does not
 * fit in the local buffer.
 */
static void log_usage(const char *progname)
{
    char buf[512];
    int n;

    if (progname == NULL) {
        progname = "wordcount";
    }

    n = snprintf(buf, sizeof(buf),
                 "Uso: %s <input_file1> [input_file2 ...]\n",
                 progname);
    if (n <= 0) {
        return;
    }
    /* snprintf returns the untruncated length; never write past what it stored. */
    if ((size_t)n >= sizeof(buf)) {
        n = (int)(sizeof(buf) - 1);
        buf[n - 1] = '\n';
    }
    write_all(STDERR_FILENO, buf, (size_t)n);
}

/*
 * Announce on stdout that the process `pid` was started to run `cmd` for the
 * pipeline step labelled `stage`.  Output that does not fit in the local
 * buffer is truncated but still newline-terminated.
 */
static void log_created(const char *stage, const char *cmd, pid_t pid)
{
    char buf[256];
    int n = snprintf(buf, sizeof(buf),
                     "[WordCount:%s] Created %s process %d.\n",
                     stage, cmd, (int)pid);

    if (n <= 0) {
        return;
    }
    /* snprintf returns the untruncated length; never write past what it stored. */
    if ((size_t)n >= sizeof(buf)) {
        n = (int)(sizeof(buf) - 1);
        buf[n - 1] = '\n';
    }
    write_all(STDOUT_FILENO, buf, (size_t)n);
}

/*
 * Report on stderr that `cmd` in pipeline step `stage` failed, with a short
 * `reason`.  Output that does not fit in the local buffer is truncated but
 * still newline-terminated.
 */
static void log_failed(const char *stage, const char *cmd, const char *reason)
{
    char buf[256];
    int n = snprintf(buf, sizeof(buf),
                     "[WordCount:%s] Error in %s (%s).\n",
                     stage, cmd, reason);

    if (n <= 0) {
        return;
    }
    /* snprintf returns the untruncated length; never write past what it stored. */
    if ((size_t)n >= sizeof(buf)) {
        n = (int)(sizeof(buf) - 1);
        buf[n - 1] = '\n';
    }
    write_all(STDERR_FILENO, buf, (size_t)n);
}

/*
 * Make the file at `path`, opened read-only, the process's standard input.
 *
 * Returns 0 on success.  Returns -1 with errno set if open(), dup2() or
 * close() fails.
 */
static int redirect_input(const char *path)
{
    int fd = open(path, O_RDONLY);
    if (fd < 0) {
        return -1;
    }

    /*
     * If stdin was already closed, open() hands back descriptor 0 itself.
     * It is then already in place; closing it would undo the redirection.
     */
    if (fd == STDIN_FILENO) {
        return 0;
    }

    if (dup2(fd, STDIN_FILENO) < 0) {
        int saved = errno; /* close() below may clobber errno */
        close(fd);
        errno = saved;
        return -1;
    }

    if (close(fd) < 0) {
        return -1;
    }

    return 0;
}

/*
 * Make the file at `path` the process's standard output.
 *
 * The file is created if missing (mode 0644, subject to umask) and truncated
 * if it already exists.  Returns 0 on success.  Returns -1 with errno set if
 * open(), dup2() or close() fails.
 */
static int redirect_output(const char *path)
{
    int fd = open(path, O_WRONLY | O_CREAT | O_TRUNC, 0644);
    if (fd < 0) {
        return -1;
    }

    /*
     * If stdout was already closed, open() may hand back descriptor 1 itself.
     * It is then already in place; closing it would undo the redirection.
     */
    if (fd == STDOUT_FILENO) {
        return 0;
    }

    if (dup2(fd, STDOUT_FILENO) < 0) {
        int saved = errno; /* close() below may clobber errno */
        close(fd);
        errno = saved;
        return -1;
    }

    if (close(fd) < 0) {
        return -1;
    }

    return 0;
}

/*
 * Execute a single pipeline step and wait for it to finish.
 *
 * A child is forked.  It optionally rebinds stdin to `input_path` and stdout
 * to `output_path` (a NULL path means "inherit"), then execvp()s `cmd` with
 * `argv`.  If a redirection or the exec fails, the child reports it on stderr
 * and exits with status 127.
 *
 * The parent logs the child's pid and blocks (retrying on EINTR) until the
 * child terminates.
 *
 * Returns 0 only if the child exited normally with status 0.  Otherwise it
 * returns -1 after logging why: fork/waitpid failure, death by signal, or a
 * non-zero exit status.
 */
static int run_stage(const char *stage,
                     const char *cmd,
                     char *const argv[],
                     const char *input_path,
                     const char *output_path)
{
    int status;
    pid_t child = fork();

    switch (child) {
    case -1:
        log_failed(stage, cmd, "fork");
        log_errno_msg("fork");
        return -1;

    case 0:
        /* Child side: wire up the requested files, then replace the image. */
        if (input_path && redirect_input(input_path) != 0) {
            log_errno_msg("redirect input");
            _exit(127);
        }
        if (output_path && redirect_output(output_path) != 0) {
            log_errno_msg("redirect output");
            _exit(127);
        }
        execvp(cmd, argv);
        /* Only reached when execvp itself failed. */
        log_errno_msg(cmd);
        _exit(127);

    default:
        break;
    }

    log_created(stage, cmd, child);

    for (;;) {
        if (waitpid(child, &status, 0) >= 0) {
            break;
        }
        if (errno != EINTR) {
            log_failed(stage, cmd, "waitpid");
            log_errno_msg("waitpid");
            return -1;
        }
    }

    if (WIFEXITED(status) && WEXITSTATUS(status) == 0) {
        return 0;
    }

    log_failed(stage, cmd,
               WIFSIGNALED(status) ? "terminated by signal"
                                   : "non-zero exit status");
    return -1;
}

/*
 * Build a NULL-terminated argv for the "map" grep stage:
 *
 *     grep [-h] -oE [[:alpha:]]+ -- FILE...
 *
 * FILE... are argv[1..argc-1].  "-h" is added only when there are several
 * files, so grep does not prefix each match with its file name.  "--" ends
 * option parsing, so a file name that begins with '-' is not taken as a
 * grep option.
 *
 * Returns a heap array the caller must free(); the strings it points to are
 * borrowed (literals or the caller's argv), not copied.  Returns NULL if
 * allocation fails.
 */
static char **build_grep_argv(int argc, char *argv[])
{
    size_t nfiles = (argc > 1) ? (size_t)(argc - 1) : 0;
    int add_hide_names = (argc > 2);
    /* "grep" [-h] "-oE" PATTERN "--" FILE... NULL */
    size_t slots = 4 + (size_t)add_hide_names + nfiles + 1;
    char **grep_argv = malloc(slots * sizeof *grep_argv);
    size_t pos = 0;
    int i;

    if (grep_argv == NULL) {
        return NULL;
    }

    grep_argv[pos++] = "grep";
    if (add_hide_names) {
        grep_argv[pos++] = "-h";
    }
    grep_argv[pos++] = "-oE";
    grep_argv[pos++] = "[[:alpha:]]+";
    grep_argv[pos++] = "--";

    for (i = 1; i < argc; ++i) {
        grep_argv[pos++] = argv[i];
    }
    grep_argv[pos] = NULL;

    return grep_argv;
}

/*
 * Word-count driver.  Runs a fixed five-step pipeline.  Each step reads and
 * writes files in the current directory:
 *
 *   map:     grep -oE [[:alpha:]]+ FILES... -> TMP1
 *   map:     tr [:upper:] [:lower:]         TMP1 -> TMP2
 *   shuffle: sort                           TMP2 -> TMP3
 *   reduce:  uniq -c                        TMP3 -> TMP4
 *   reduce:  sort -nr                       TMP4 -> RESULT
 *
 * Steps run one after another.  The first failing step stops the run and the
 * program exits with status 1.  Intermediate files are not removed.
 */
int main(int argc, char *argv[])
{
    /* One pipeline step: log label, program, its argv, stdin/stdout files. */
    struct stage_spec {
        const char *stage;
        const char *cmd;
        char *const *args;
        const char *input;
        const char *output;
    };

    char *tr_argv[]      = {"tr", "[:upper:]", "[:lower:]", NULL};
    char *sort_argv[]    = {"sort", NULL};
    char *uniq_argv[]    = {"uniq", "-c", NULL};
    char *sort_nr_argv[] = {"sort", "-nr", NULL};
    char **grep_argv;
    size_t i;
    int exit_code = 0;

    if (argc < 2) {
        log_usage(argv[0]);
        return 1;
    }

    grep_argv = build_grep_argv(argc, argv);
    if (grep_argv == NULL) {
        log_errno_msg("malloc");
        return 1;
    }

    {
        const struct stage_spec stages[] = {
            {"map",     "grep", grep_argv,    NULL, TMP1},
            {"map",     "tr",   tr_argv,      TMP1, TMP2},
            {"shuffle", "sort", sort_argv,    TMP2, TMP3},
            {"reduce",  "uniq", uniq_argv,    TMP3, TMP4},
            {"reduce",  "sort", sort_nr_argv, TMP4, RESULT},
        };

        for (i = 0; i < sizeof(stages) / sizeof(stages[0]); ++i) {
            const struct stage_spec *s = &stages[i];
            if (run_stage(s->stage, s->cmd, s->args, s->input, s->output) != 0) {
                exit_code = 1;
                break;
            }
        }
    }

    /* Single cleanup point for the heap-allocated grep argv. */
    free(grep_argv);

    if (exit_code == 0) {
        write_str(STDOUT_FILENO,
                  "[WordCount:wordcount] WordCount completed successfully.\n");
    }
    return exit_code;
}
