/*===========================================================================
*
*                            PUBLIC DOMAIN NOTICE
*               National Center for Biotechnology Information
*
*  This software/database is a "United States Government Work" under the
*  terms of the United States Copyright Act.  It was written as part of
*  the author's official duties as a United States Government employee and
*  thus cannot be copyrighted.  This software/database is freely available
*  to the public for use. The National Library of Medicine and the U.S.
*  Government have not placed any restriction on its use or reproduction.
*
*  Although all reasonable efforts have been taken to ensure the accuracy
*  and reliability of the software and data, the NLM and the U.S.
*  Government do not and cannot warrant the performance or results that
*  may be obtained by using this software or data. The NLM and the U.S.
*  Government disclaim all warranties, express or implied, including
*  warranties of performance, merchantability or fitness for any particular
*  purpose.
*
*  Please cite the author in any work or product based on this material.
*
* ===========================================================================
*
*/

#include <klib/extern.h>
#include "log-priv.h"
#include <klib/writer.h>
#include "writer-priv.h"
#include <klib/text.h>
#include <klib/rc.h>
#include <klib/symbol.h>
#include <sysalloc.h>
#include <os-native.h> /* for strchrnul on non-linux */

#include <assert.h>
#include <string.h>
#include <stdarg.h>
#include <stdio.h>
#include <stdlib.h>
#include <ctype.h>

/*
 * copy_chars
 *  append up to "cnt" bytes of "src" at *dst, but never more than *lim
 *
 *  advances *dst and decrements *lim by the number of bytes actually
 *  copied, which is also the return value. no NUL terminator is written.
 */
static
size_t CC copy_chars ( char ** dst, size_t * lim, const char * src, size_t cnt )
{
    size_t const n = ( * lim < cnt ) ? * lim : cnt;

    memcpy ( * dst, src, n );
    * dst += n;
    * lim -= n;

    return n;
}

/*
 * send all characters before a '%' straight on through
 * but handle %% -> % here
 */
/*
 * normal_chars
 *  copy literal text of the format string straight through to the output
 *
 *  copies every character before the next '%' that introduces a
 *  conversion (or before the end of the format), translating each "%%"
 *  into a single '%'. one byte of output is always reserved for a NUL,
 *  and the output is re-terminated after every write.
 *
 *  "_pout" [ IN/OUT ] - output cursor, advanced past the bytes written
 *
 *  "_pin" [ IN/OUT ] - format cursor, advanced past the input consumed.
 *   on return it points at a '%' introducing a conversion, at the
 *   terminating NUL, or - if the output filled up - inside the literal text
 *
 *  "_limit" [ IN/OUT ] - bytes available at *_pout. set to 0 when the
 *   literal text did not fit (only the terminator was written), which
 *   tells the caller to stop
 *
 *  returns the number of characters written, not counting the NUL
 */
static
size_t CC normal_chars ( char ** _pout, const char ** _pin, size_t * _limit )
{
    char * pout;
    const char * pin;
    size_t count;
    size_t limit;

    assert ( _pin != NULL );
    assert ( _pout != NULL );
    assert ( _limit != NULL );
    assert ( * _pin != NULL );
    assert ( * _pout != NULL );

    /* make pointers 'local' and initialize count */
    pout = * _pout;
    pin = * _pin;
    limit = * _limit;
    count = 0;

    /* loop in case we hit a "%%" */
    while ( * pin != '\0' && limit != 0 )
    {
        /* pointer to the next per cent (or the end) and the literal run before it */
        const char * ppc = strchrnul ( pin, '%' );
        size_t seg_size = ( size_t ) ( ppc - pin );

        /* the run plus its terminator does not fit: copy what we can,
         * terminate, and report the output as full. updating limit here
         * is what stops the caller from writing past the buffer */
        if ( seg_size >= limit )
        {
            seg_size = limit - 1;
            memcpy ( pout, pin, seg_size );
            pout += seg_size;
            pin += seg_size;
            count += seg_size;
            * pout = '\0';
            limit = 0;
            break;
        }

        if ( seg_size != 0 )
        {
            /* copy bytes to the output and fix up pointers and count */
            memcpy ( pout, pin, seg_size );
            pout += seg_size;
            pin += seg_size;
            count += seg_size;
            limit -= seg_size;
            * pout = '\0';
        }

        /* a lone '%' (or the end): leave it for the caller */
        if ( pin [ 0 ] != '%' || pin [ 1 ] != '%' )
            break;

        /* "%%" needs room for the '%' and the terminator */
        if ( limit < 2 )
        {
            /* no room: drop the escape and report the output as full */
            * pout = '\0';
            pin += 2;
            limit = 0;
            break;
        }

        /* copy a single '%' and fix up pointers and count */
        * pout ++ = '%';
        * pout = '\0';
        pin += 2;
        count ++;
        limit --;
    }

    /* write back new pointer values */
    * _pout = pout;
    * _pin = pin;
    * _limit = limit;

    return count;
}

/*
 * KWrtFmt_KSymbol
 *  render a KSymbol with its ancestors: each "dad" is printed first
 *  (recursively through "%N") followed by ':', then the symbol's own name
 *
 *  "buffer" [ OUT ] and "bufsize" [ IN ] - destination
 *  "ignored_fmt" [ IN ] - unused
 *  "sym" [ IN, NULL OKAY ] - NULL is rendered as "(null)"
 *
 *  returns the number of characters produced, not counting the NUL
 *  (matching what knprintf reports for the non-NULL case)
 */
static
size_t CC KWrtFmt_KSymbol ( char * buffer, size_t bufsize, char * ignored_fmt, KSymbol * sym )
{
    size_t total = 0;

    ( void ) ignored_fmt;

    if ( sym == NULL )
    {
        static const char null [] = "(null)";
        size_t const len = sizeof null - 1;

        /* copy what fits and always terminate */
        if ( bufsize != 0 )
        {
            size_t const n = ( len < bufsize ) ? len : bufsize - 1;
            memcpy ( buffer, null, n );
            buffer [ n ] = '\0';
        }
        return len;
    }

    if ( sym -> dad != NULL )
    {
        size_t const used = knprintf ( buffer, bufsize, "%N:", sym -> dad );

        total = used;

        /* buffer exhausted: nothing more fits, and calling knprintf
         * again with a zero size would trip its bufsize assertion */
        if ( used >= bufsize )
            return total;

        buffer += used;
        bufsize -= used;
    }

    total += knprintf ( buffer, bufsize, "%S", & sym -> name );
    return total;
}


/*
 * COMMON_FIXUP
 *  append the conversion character "tchar" to the working format "fmt",
 *  NUL-terminate it and record its length in "fmt_size".
 *
 *  deliberately NOT wrapped in do { } while (0): when "fmt" is full the
 *  "break" must leave the enclosing switch so the caller's
 *  format-overflow handling takes over.
 *
 *  expects in scope: fmt, pfmt, pfmt_lim, fmt_size, tchar
 */
#define COMMON_FIXUP()                          \
    *pfmt++ = tchar;                            \
    if (pfmt >= pfmt_lim)                       \
        break;                                  \
    *pfmt = '\0';                               \
    fmt_size = (size_t)(pfmt - fmt)

/*
 * CALL_SNPRINTF
 *  format VALUE (converted to TYPE) at "pout" using the working format
 *  "fmt", forwarding any '*' width and/or precision collected while
 *  parsing. the snprintf result is stored in "lcount".
 *
 *  wrapped in do { } while (0) so it behaves as a single statement,
 *  e.g. as the body of an unbraced if/else.
 *
 *  expects in scope: pout, limit, fmt, width, precision,
 *  saw_width, saw_precision, lcount
 */
#define CALL_SNPRINTF(TYPE, VALUE)                                      \
    do                                                                  \
    {                                                                   \
        if (saw_width && saw_precision)                                 \
            lcount = snprintf (pout, limit, fmt, width,                 \
                               precision, (TYPE)(VALUE));               \
        else if (saw_width)                                             \
            lcount = snprintf (pout, limit, fmt, width, (TYPE)(VALUE)); \
        else if (saw_precision)                                         \
            lcount = snprintf (pout, limit, fmt, precision,             \
                               (TYPE)(VALUE));                          \
        else                                                            \
            lcount = snprintf (pout, limit, fmt, (TYPE)(VALUE));        \
    } while (0)

/*
 * knprintf
 *  formatted print into a fixed-size buffer
 *
 *  variadic front end: gathers the arguments into a va_list and hands
 *  everything to vknprintf, whose result is returned unchanged.
 */
LIB_EXPORT size_t CC knprintf ( char * buf, size_t bufsize, const char * fmt, ... )
{
    size_t result;
    va_list args;

    va_start ( args, fmt );
    result = vknprintf ( buf, bufsize, fmt, args );
    va_end ( args );

    return result;
}

/*
 * vknprintf
 *  format "str" with the arguments in "ap" into "buf"
 *
 *  standard conversions d i u x X o e f g E F G c p s n are passed to
 *  snprintf together with their flags, '*' width/precision and length
 *  modifiers ('l' and 'j' - and 't'/'z' on 64-bit builds - select a
 *  64-bit integer argument). klib extensions:
 *   %V     - ver_t, printed as major[.minor[.release]] with trailing
 *            zero parts omitted
 *   %R     - rc_t via KWrtFmt_rc_t; the '#' flag selects its "#" form
 *   %N     - KSymbol *, printed with its ancestors
 *   %S     - String *
 *   %(x)T  - a pfunc_t is popped and called with the text "(x" and one
 *            value whose type T is one of i d o x X, j/l + one of those,
 *            e f g E F G, L + one of those, or p
 *  literal text is copied through, with "%%" becoming "%".
 *
 *  "buf" [ OUT ] and "bufsize" [ IN ] - output buffer; bufsize must be > 0
 *  "str" [ IN ] - format string
 *
 *  returns the number of characters written, not counting the NUL. when
 *  the output did not fit it is truncated and "bufsize" is returned.
 *  a malformed conversion is echoed followed by "<FMT-ERROR>\n" and
 *  formatting stops there (in that case the count includes the marker's NUL).
 */
LIB_EXPORT size_t CC vknprintf ( char * buf, size_t bufsize, const char * str, va_list ap )
{
    char * pout;
    const char * pin;
    int * n_fmt;
    size_t count;
    int lcount;
    size_t limit;

    /*
     * superficial evaluation of parameters
     */
    assert (buf);
    assert (str);
    assert (bufsize);

    pout = buf;
    pin = str;
    count = 0;
    limit = bufsize;

    /* an empty format must still produce a terminated (empty) string */
    if (bufsize != 0)
        *buf = '\0';

    while (*pin)
    {
        /* abort if buf is full */
        if (limit == 0)
            break;

        /* handle characters up to a single '%' */
        count += normal_chars (&pout, &pin, &limit);

        /* literal text did not fit: report truncation like the snprintf path */
        if (limit == 0)
        {
            count = bufsize;
            break;
        }

        /* can't break this out into a function because of the va_list */
        if (*pin == '%') /* same result as (*pin != '\0') if all is right */
        {
            size_t fmt_size;
            char * pfmt;
            char * pfmt_lim;
            void * pvalue;

            uint64_t livalue;
            double    fvalue;
            long double dfvalue;

            /* values for '*' for width and/or precision */
            int width;
            int precision;
            int ivalue;

            char fmt [128]; /* big enough?  I'd sure hope so. */

            char tchar;

            bool saw_width;
            bool saw_precision;
            bool saw_dot;
            bool saw_fmt_error;
            bool is_64_bit;

            saw_width = saw_precision = saw_dot = saw_fmt_error = false;
            width = 0;  /* shouldn't need this. */
            precision = 0; /* scary- will do the wrong thing is badly set */
            pfmt = fmt;
            pfmt_lim = pfmt + sizeof (fmt);

            *pfmt++ = '%';
            pin++;
            while (pfmt < pfmt_lim)
            {
                switch (tchar = *pin++)
                {
                case '\0':
                    goto fmt_error;

                default:
                    /*
                     * We catch some errors:
                     * characters we just don't recognize as
                     * belonging in a conversion specification
                     */
                    saw_fmt_error = true;
                    *pfmt++ = tchar;
                    goto fmt_error;

                    /* legal characters to be in a conversion specification
                     * we are not parsing except for gross errors like
                     * completely wrong characters
                     */
                    /* legal in flags but not processed here */
                case '+':
                case '-':
                case '#':

                    /* legal in width and precision (and 0 is flags) but not processed here */
                case '0':
                case '1':
                case '2':
                case '3':
                case '4':
                case '5':
                case '6':
                case '7':
                case '8':
                case '9':

                    /* legal in length but not processed here */
                case 'h':
                case 'l':
                case 'L':
                case 'z':
                case 'j':
                case 't':

                    /* most chars just get put in the format */
                    *pfmt++ = tchar;
                    break;

                case '.':
                    /* a period in a conversion specifier is a 
                     * switch to the optional precision portion
                     */
                    if (saw_dot)
                        saw_fmt_error = true;
                    else
                        saw_dot = true;
                    *pfmt++ = tchar;
                    break;

                case '*':
                    /* are we in the precision or the width? */
                    if (saw_dot)
                    {
                        /* had we seen this already? */
                        if (saw_precision)
                            saw_fmt_error = true;
                        else
                        {
                            saw_precision = true;
                            precision = va_arg (ap, int);
                        }
                    }
                    else
                    {
                        /* had we seen this already? */
                        if (saw_width)
                            saw_fmt_error = true;
                        else
                        {
                            saw_width = true;
                            width = va_arg (ap, int);
                        }
                    }
                    *pfmt++ = tchar;
                    break;


                    /* -----
                     * standard output conversion types
                     */

                    /*
                     * standard integer type formats: 
                     * we don't care about  base or signedness here
                     */
                case 'd':
                case 'i':
                case 'u':
                case 'X':
                case 'x':
                case 'o':

                    COMMON_FIXUP();

                    is_64_bit = false;
                    if (fmt_size > 2) /* long enough to have a length character */
                    {
                        switch (*(pfmt-2))
                        {
                        case 'j':
                        case 'l': /* for this project 'l' implies 64-bit not long int */
#if (_ARCH_BITS == 64)
                        case 't':
                        case 'z':
#endif
                            is_64_bit = true;
                            break;
                        default:
                            break;
                        }
                    }

                    /*
                     * this likely will do something only for windows
                     * unless we find flaws in the Unix-like
                     */
                    print_int_fixup (fmt, &fmt_size, limit);

                    if (is_64_bit)
                    {
                        uint64_t value;
                        value = va_arg (ap, uint64_t);
                        CALL_SNPRINTF (uint64_t, value);
                    }
                    else
                    {
                        int value;
                        value = va_arg (ap, int);
                        CALL_SNPRINTF (int, value);
                    }
                    goto handle_snprintf_return;

                    /*
                     * standard float type formats
                     */
                case 'e':
                case 'f':
                case 'g':
                case 'E':
                case 'F':
                case 'G':

                    COMMON_FIXUP();

                    print_float_fixup (fmt, &fmt_size, limit);

                    if ((fmt_size > 2) && (*(pfmt-2) == 'L'))
                    {
                        dfvalue = va_arg (ap, long double);
                        CALL_SNPRINTF (long double, dfvalue);
                    }
                    else
                    {
                        fvalue = va_arg (ap, double);
                        CALL_SNPRINTF (double, fvalue);
                    }
                    goto handle_snprintf_return;
                    
                    /*
                     * standard char type formats
                     */
                case 'c':

                    COMMON_FIXUP();

                    print_char_fixup (fmt, &fmt_size, limit);

                    tchar = (char)va_arg (ap, int);
                    if (fmt_size == 2)
                    {
                        /* plain "%c": emit directly, but keep snprintf's
                         * contract of always terminating the output */
                        if (limit > 1)
                        {
                            pout [0] = tchar;
                            pout [1] = '\0';
                        }
                        else if (limit == 1)
                            pout [0] = '\0';
                        lcount = 1;
                    }
                    else
                    {
                        CALL_SNPRINTF (char, tchar);
                    }
                    goto handle_snprintf_return;

                    /*
                     * standard pointer type formats
                     */
                case 'p':
                case 's':

                    COMMON_FIXUP();

                    pvalue = va_arg (ap, void *);
                    CALL_SNPRINTF (void*, pvalue);

                handle_snprintf_return:
                    if (lcount < 0)
                    {
                        /* snprintf rejected the conversion: report it and stop */
                        saw_fmt_error = true;
                        goto fmt_error;
                    }

                    if ((size_t)lcount >= limit)
                    {
                        /* output truncated: report the buffer as completely used */
                        limit = 0;
                        count = bufsize;
                        pout = buf+bufsize;
                    }
                    else
                    {
                        /* keep limit in step with pout so later writes stay in bounds */
                        count += (size_t)lcount;
                        pout += lcount;
                        limit -= (size_t)lcount;
                    }
                    goto fmt_done;

                case 'n':
                    n_fmt = va_arg (ap, void*);
                    *n_fmt = (int)count;
                    goto fmt_done;

                    /* -----
                     * our extended output conversion types
                     */
                case 'V':
                    tchar = 's';
                    COMMON_FIXUP();
                    {
                        size_t v_count;
                        ver_t v;
                        char v_buf [128];
                        const char *version_fmt;

                        v = va_arg (ap, ver_t);

                        if ( ( v & 0xFFFF ) != 0 )
                            version_fmt ="%u.%u.%u";
                        else if ( ( v & 0xFF0000 ) != 0 )
                            version_fmt ="%u.%u";
                        else
                            version_fmt ="%u";

                        v_count = snprintf(v_buf, sizeof v_buf,
                                           version_fmt,
                                           VersionGetMajor(v),
                                           VersionGetMinor(v),
                                           VersionGetRelease(v));

                        assert (v_count < sizeof (v_buf));

                        CALL_SNPRINTF (const char *, v_buf);
                    }
                    goto handle_snprintf_return;

                case 'R': /* rc_t */
                    tchar = 's';
                    COMMON_FIXUP();
                    {
                        size_t ix;
                        size_t rc_count;
                        rc_t rc_in;
                        bool alternative;
                        char rc_buf [128];

                        rc_in = va_arg (ap, rc_t);

                        /* remove the '#' flag from fmt, shifting the
                         * terminating NUL down as well (hence <=) */
                        alternative = false;
                        for (ix = 1; ix <= fmt_size; ++ix)
                        {
                            if (alternative)
                                fmt[ix-1] = fmt[ix];
                            else if (fmt[ix] == '#')
                                alternative = true;
                        }
                        if (alternative)
                        {
                            fmt_size --;
                            rc_count = KWrtFmt_rc_t (rc_buf, sizeof (rc_buf), "#", rc_in);
                        }
                        else
                            rc_count = KWrtFmt_rc_t (rc_buf, sizeof (rc_buf), "", rc_in);

                        assert (rc_count < sizeof (rc_buf));

                        CALL_SNPRINTF (const char *, rc_buf);
                    }
                    goto handle_snprintf_return;

                case 'N': /* KSymbol */
                    tchar = 's';
                    COMMON_FIXUP();
                    {
                        size_t rc_count;
                        KSymbol * sym_in;
                        char sym_buf [1024];

                        sym_in = va_arg (ap, void *);
                        
                        rc_count = KWrtFmt_KSymbol (sym_buf, sizeof (sym_buf), NULL, sym_in);

                        assert (rc_count < sizeof (sym_buf));

                        CALL_SNPRINTF (const char *, sym_buf);
                    }
                    goto handle_snprintf_return;

                case 'S': /* String */

                    tchar = 's';
                    COMMON_FIXUP();

                    {
                        String * s_str;
                        int s_count;
                        char s_buf [4000];
                        s_str = va_arg (ap, String *);
                        s_count = (int)s_str->size;
                        s_count = snprintf (s_buf, sizeof (s_buf), "%.*s", s_count, s_str->addr);

                        CALL_SNPRINTF (const char *, s_buf);
                    }
                    goto handle_snprintf_return;


                case '(':
                {
                    /* user-supplied formatter: "%(text)T" */
                    char lbuf [8192];
                    char lfmt [128];
                    char * plfmt;
                    const char * close_lfmt;
                    const char * pin_save;
                    size_t fsize;
                    size_t lsize;
                    size_t(* pfunc)(char *, size_t, const char *, ...);

                    /* NOTE(review): lfmt keeps the leading '(' and is passed
                     * to pfunc as "(text" - confirm formatters expect that */
                    plfmt = lfmt;
                    *plfmt++ = '(';

                    close_lfmt = strchrnul (pin, ')');
                    fsize = (size_t)(close_lfmt - pin);

                    /* need a ')' and room in lfmt for the text and a NUL
                     * (one byte already holds the '(') */
                    if ((*close_lfmt != ')') || (fsize > sizeof (lfmt) - 2))
                    {
                    error_lfmt:
                        /* echo "(..." up to close_lfmt as part of the bad
                         * specification, bounded by the room left in fmt */
                        fsize = (size_t)(close_lfmt - --pin);
                        lsize = (size_t)(pfmt_lim - pfmt);
                        pin += copy_chars (&pfmt, &lsize, pin, fsize);
                        saw_fmt_error = true;
                        goto fmt_error;
                    }

                    /* got a close */
                    pin_save = pin;

                    lsize = sizeof (lfmt) - 2;
                    pin += copy_chars (&plfmt, &lsize, pin, fsize);
                    *plfmt = '\0';

                    /* step over the ')': the value type follows it */
                    ++pin;

                    /* pop off function pointer */
                    pfunc = va_arg (ap, pfunc_t);

                    lbuf [0] = '\0';

                    switch (tchar = *pin++)
                    {
                        /* we only accept a limited set of types
                         * after ():
                         * int (i), 64bit int (ji), double (f),
                         * long double (Lf), and pointer (p)
                         */
                    default:
                    error_l_type:
                        /* never echo the format's terminating NUL */
                        close_lfmt = (tchar == '\0') ? pin - 1 : pin;
                        pin = pin_save;
                        goto error_lfmt;

                    case 'p':
                        pvalue = va_arg (ap, void*);
                        pfunc (lbuf, sizeof (lbuf), lfmt, pvalue);
                        break;

                    case 'i':
                    case 'd':
                    case 'o':
                    case 'x':
                    case 'X':
                        ivalue = va_arg (ap, int);
                        pfunc (lbuf, sizeof (lbuf), lfmt, ivalue);
                        break;

                    case 'e':
                    case 'f':
                    case 'g':
                    case 'E':
                    case 'F':
                    case 'G':
                        fvalue = va_arg (ap, double);
                        pfunc (lbuf, sizeof (lbuf), lfmt, fvalue);
                        break;

                    case 'j':
                    case 'l':
                        tchar = *pin++;
                        switch (tchar)
                        {
                        default:
                            goto error_l_type;
                        case 'i':
                        case 'd':
                        case 'o':
                        case 'x':
                        case 'X':
                            break;
                        }
                        livalue = va_arg (ap, uint64_t);
                        pfunc (lbuf, sizeof (lbuf), lfmt, livalue);
                        break;

                    case 'L':
                        tchar = *pin++;
                        switch (tchar)
                        {
                        default:
                            goto error_l_type;
                        case 'e':
                        case 'f':
                        case 'g':
                        case 'E':
                        case 'F':
                        case 'G':
                            break;
                        }
                        dfvalue = va_arg (ap, long double);
                        pfunc (lbuf, sizeof (lbuf), lfmt, dfvalue);
                        break;

                    } /* switch (tchar = *pin++) */

                    /* don't rely on pfunc to terminate a full buffer */
                    lbuf [sizeof (lbuf) - 1] = '\0';

                    tchar = 's';
                    COMMON_FIXUP();
                    CALL_SNPRINTF (char *, lbuf);
                    goto handle_snprintf_return;
                } /* case '(': */
                } /* switch (tchar = *pin++) */
            } /* while (pfmt < pfmt_lim) */
            if (pfmt >= pfmt_lim)
            {
                /* if we get here we had a format overflow */
                static const char msg[] = "<FMT-OVERFLOW>\n";
                count += copy_chars (&pout, &limit, msg, sizeof (msg));
                count += copy_chars (&pout, &limit, fmt, (size_t)(pfmt_lim-fmt));
            }
            else if (saw_fmt_error)
            {
                /* malformed conversion: echo what was parsed,
                 * append the marker, terminate and stop formatting */
                static const char msg[] = "<FMT-ERROR>\n";

            fmt_error:
                count += copy_chars (&pout, &limit, fmt, (size_t)(pfmt-fmt));
                count += copy_chars (&pout, &limit, msg, sizeof (msg));
                if (limit > 0)
                    *pout = '\0';
                else
                    *--pout = '\0';
                break;
            }
        fmt_done:
            ;
        } /* if (*pin == '%') */
    } /* while (*pin) */
    return count;
}


/* EOF */
