/*
  TODO:
  search $PATH
  exit if execve fails in child
  do -r in a cheaper fashion: cache or just read binary from disk?
  symbols!
   libbfd?
*/
#include <assert.h>
#include <ctype.h>
#include <errno.h>
#include <limits.h>
#include <signal.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/types.h>
#include <sys/mman.h>
#include <sys/ptrace.h>
#include <sys/user.h>
#include <sys/wait.h>
#include <syscall.h>
#include <unistd.h>

/*
 * PC_REG names the instruction-pointer member of struct user_regs_struct
 * (<sys/user.h>) for the word size this file is compiled for; it is used as
 * `regs.PC_REG` wherever the current program counter is needed.
 */
#if __WORDSIZE != 64
#define PC_REG eip  /* 32-bit register layout */
#else
#define PC_REG rip  /* 64-bit register layout */
#endif

/* Command-line configuration; filled in by parse_options(). */
typedef struct {
    int pid;          /* -p: pid to attach to; 0 means fork/exec argv instead */
    char **argv;      /* command to run; argv[0] is handed to execve() as-is */
    char **envp;      /* environment for that command */
    int maxsteps;     /* -c: stop after this many steps; 0 means no limit */
    char *outfile;    /* -o: output path, or NULL when tracing to stdout */
    FILE *outfd;      /* stream for the PC trace (a FILE *, despite the name) */
    int binary;       /* -b: write raw register values instead of text */
    int only_rets;    /* -r: log the PC only when it points at a return */
    char *mapfile;    /* -m: path receiving /proc/PID/maps snapshots */
    FILE *mapfd;      /* stream opened on mapfile, or NULL */
    char *sysfile;    /* -s: path receiving the syscall log */
    FILE *sysfd;      /* stream opened on sysfile, or NULL */
    int interactive;  /* -i: read a character from stdin between steps */
} opt_t;

static opt_t opt;

/*
 * Set asynchronously by handle_int() and polled by the main trace loop,
 * which then detaches from (or kills) the target. An object written from a
 * signal handler must be `volatile sig_atomic_t` for the access to be
 * well-defined (C11 7.14.1.1p5); a plain int is not guaranteed to be.
 */
static volatile sig_atomic_t time_to_detach;

/* Handler for SIGINT/SIGTERM/SIGHUP: ask the trace loop to stop. */
void handle_int(int dummy) {
    (void)dummy;  /* which signal arrived does not matter */
    time_to_detach = 1;
}

/*
 * Start argv[0] as a traced child process.
 *
 * The child calls PTRACE_TRACEME and then execve(argv[0], argv, envp).
 * argv[0] must be a path, because execve() does not search $PATH. On a
 * successful exec the child stops with SIGTRAP, ready for the parent to
 * single-step it.
 *
 * If PTRACE_TRACEME or execve() fails, the child prints the reason on
 * stderr and exits with status 127; the parent sees this as a normal exit
 * from wait(). It never returns into the tracer's own code.
 *
 * Returns the child's pid in the parent. If fork() fails, the diagnostic is
 * printed and the whole program exits with status 1.
 */
int do_forkexec(char **argv, char **envp) {
    int pid = fork();

    if (pid == -1) {
        perror("fork");
        exit(1);
    }
    if (pid == 0) {
        /* Child. _exit() (not exit()/assert) avoids flushing stdio buffers
         * inherited from the parent and never falls back into the caller. */
        if (ptrace(PTRACE_TRACEME, 0, 0, 0) == -1) {
            perror("ptrace(PTRACE_TRACEME)");
            _exit(127);
        }
        execve(argv[0], argv, envp);
        perror(argv[0]);
        _exit(127);
    }
    return pid;
}

/*
 * Print the command-line synopsis and option summary to stderr.
 * BINARY is the program name shown on the first line (normally argv[0]).
 */
void usage(char *binary) {
    static const char option_lines[] =
        "-p <pid>          attach to already running process\n"
        "-c <count>        trace only count instructions\n"
        "-o <file>         print output to file instead of stdout\n"
        "-b                use binary output format instead of ASCII\n"
        "-r                print EIP only when it points to RET instruction\n"
        "-m <file>         write info about changes in memory mappings to file\n"
        "-s <file>         write trace of syscalls to file (like strace, used for internal debugging)\n"
        "-i                interactive mode, currently just stops after each instruction\n";

    fprintf(stderr, "usage: %s [OPTION]... [COMMAND]\n", binary);
    fputs(option_lines, stderr);
}

/*
 * Convert option argument ARG (given to option -FLAG) to an int in the range
 * [MIN, INT_MAX]. Unlike atoi(), this rejects empty strings, trailing garbage
 * and out-of-range values: it prints a diagnostic and exits with status 1.
 */
static int parse_int_arg(const char *arg, int flag, long min) {
    char *end;
    long val;

    errno = 0;
    val = strtol(arg, &end, 10);
    if (end == arg || *end != '\0' || errno == ERANGE
        || val < min || val > INT_MAX) {
        fprintf(stderr, "invalid argument to -%c: '%s'\n", flag, arg);
        exit(1);
    }
    return (int)val;
}

/*
 * Open PATH for writing and return the stream. On failure, report the error
 * with perror() and exit with status 1. OLD, if non-NULL, is closed first, so
 * repeating an option does not leak the earlier stream.
 */
static FILE *open_output(FILE *old, const char *path) {
    FILE *f;

    if (old) {
        fclose(old);
    }
    f = fopen(path, "w");
    if (!f) {
        perror(path);
        exit(1);
    }
    return f;
}

/*
 * Fill the global `opt` from the command line.
 *
 * The leading '+' in the optstring makes option parsing stop at the first
 * non-option word, so the traced command may carry its own flags. Without
 * -p, the remaining words form the command to run (opt.argv), and envp
 * becomes its environment.
 *
 * An unknown option or a missing command prints usage and exits with
 * status 1. A malformed numeric argument, or an output file that cannot be
 * opened, prints a diagnostic and exits with status 1.
 */
void parse_options(int argc, char **argv, char **envp) {
    int ret;

    memset(&opt, 0, sizeof(opt_t));
    opt.outfd = stdout;
    while ((ret = getopt(argc, argv, "+p:c:o:brm:s:i")) != -1) {
        switch (ret) {
        case 'p':
            opt.pid = parse_int_arg(optarg, 'p', 1);
            break;
        case 'c':
            /* 0 is accepted and keeps the meaning "no limit". */
            opt.maxsteps = parse_int_arg(optarg, 'c', 0);
            break;
        case 'o':
            /* Only close the previous stream if it is not the default stdout. */
            opt.outfd = open_output(opt.outfile ? opt.outfd : NULL, optarg);
            opt.outfile = optarg;
            break;
        case 'b':
            opt.binary = 1;
            break;
        case 'r':
            opt.only_rets = 1;
            break;
        case 'm':
            opt.mapfd = open_output(opt.mapfd, optarg);
            opt.mapfile = optarg;
            break;
        case 's':
            opt.sysfd = open_output(opt.sysfd, optarg);
            opt.sysfile = optarg;
            break;
        case 'i':
            opt.interactive = 1;
            break;
        default:
            usage(argv[0]);
            exit(1);
        }
    }
    if (!opt.pid) {
        if (optind == argc) {
            usage(argv[0]);
            exit(1);
        }
        opt.argv = argv + optind;
        opt.envp = envp;
    }
}

int main(int argc, char **argv, char **envp) {
    int ret, pid, steps = 0, maps_dirty = 1;
    struct user_regs_struct regs;

    parse_options(argc, argv, envp);
    pid = opt.pid;
    if (pid) {
        ret = ptrace(PTRACE_ATTACH, pid, 0, 0);
        assert(!ret);
    } else {
        pid = do_forkexec(opt.argv, opt.envp);
    }
    
    signal(SIGINT, handle_int);
    signal(SIGTERM, handle_int);
    signal(SIGHUP, handle_int);
    for(;;) {
        ret = wait(0);
        assert(ret == pid);
        ret = ptrace(PTRACE_GETREGS, pid, 0, &regs);
        if (ret) {
            /* Target process exited? */
            assert(ret == -1); 
            break;
        }
        if (time_to_detach
            || (opt.maxsteps && steps >= opt.maxsteps)) {
            if (opt.pid) {
                ret = ptrace(PTRACE_DETACH, pid, 0, 0);
                assert(!ret);
            } else {
                ret = ptrace(PTRACE_KILL, pid, 0, 0);
                assert(!ret);
            }
            break;
        }
        {
            int opcode, ret_detected, int_detected;
            if (opt.only_rets || opt.mapfile || opt.sysfile) {
                /* use cache here? every extra syscall counts... */
                opcode = ptrace(PTRACE_PEEKTEXT, pid, regs.PC_REG, NULL);
                ret_detected = ((opcode & 0xff) == 0xc3 || (opcode & 0xff) == 0xc2);
#if __WORDSIZE == 64
                int_detected = ((opcode & 0xffff) == 0x050F); /* syscall */
#else
                int_detected = ((opcode & 0xffff) == 0x80CD);
#endif
            }
            if (opt.sysfile) {
                if (int_detected) {
                    fprintf(opt.sysfd, "%d syscall(%ld, %ld, %ld, %ld, %ld, %ld)\n",
			    steps,
#if __WORDSIZE == 64
                            regs.rax, regs.rbx, regs.rcx, regs.rdx, regs.rsi, regs.rdi);
#else
                            regs.eax, regs.ebx, regs.ecx, regs.edx, regs.esi, regs.edi);
#endif
                }
            }
            if (opt.mapfile) {
		if (maps_dirty) {
		    char mapspath[FILENAME_MAX];
		    FILE *maps;
		    int ch;
		    
		    sprintf(mapspath, "/proc/%d/maps", pid);
		    maps = fopen(mapspath, "r");
		    assert(maps);
		    fprintf(opt.mapfd, "%d\n", steps);
		    while ((ch = fgetc(maps)) != EOF) {
			fputc(ch, opt.mapfd);
		    }
		    fclose(maps);
		    maps_dirty = 0;
		}
                if (int_detected) {
#if __WORDSIZE == 64
		    switch (regs.rax) {
#else
		    switch (regs.eax) {
#endif
		    case SYS_mmap:
#if __WORDSIZE != 64
		    case SYS_mmap2:
#endif
		    case SYS_mremap:
		    case SYS_munmap:
		    case SYS_uselib:
			maps_dirty = 1;
			break;
		    default:
			break;
		    }
		}
            }
            if (!opt.only_rets || ret_detected) {
                if (opt.binary) {
                    fwrite(&regs.PC_REG, sizeof(regs.PC_REG), 1, opt.outfd);
                } else {
                    fprintf(opt.outfd, "%p\n", (void*)regs.PC_REG);
                }
            }
        }
        if (opt.interactive) {
            getchar();
        }
        ret = ptrace(PTRACE_SINGLESTEP, pid, 0, 0);
        assert(!ret);
	steps++;
    }
    if (opt.outfile) {
        fclose(opt.outfd);
    }
    if (opt.sysfile) {
        fclose(opt.sysfd);
    }
    if (opt.mapfile) {
        fclose(opt.mapfd);
    }
    return 0;
}
