/* Copyright (C) 2009 Keith Crane

This file is part of DFILE Tools.

DFILE Tools is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or (at
your option) any later version.

DFILE Tools is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License along
with DFILE Tools; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>. */

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include "tbox.h"

/*
** Revision-control identification string for this source file
** (expanded "$Id$" keyword).
*/
static const char       rcsid[] = "$Id: heap_sort.c,v 1.2 2009/10/16 18:00:43 keith Exp $";

/*
** $Log: heap_sort.c,v $
** Revision 1.2  2009/10/16 18:00:43  keith
** Added GPL to source code.
**
** Revision 1.1  2009/02/14 09:25:38  keith
** Initial revision
**
*/

/*
** Forward declaration of the sift helper that is defined after
** heap_sort() in this file.  The arguments are, in order: the one-based
** array origin, the element size in bytes, the index of the sub-heap
** root, the index of the last heap element, the address of the slot
** being vacated, the saved element to place, and the comparison
** function.
*/
void siftup( char *, size_t, size_t, size_t, char *, void *, int ( * )( const void *, const void * ) );

/*
** Sort an array in place with heap sort, into ascending order as defined
** by cmp.
**
** base          address of the first element; must not be null (asserted)
** element_cnt   number of elements in the array
** element_size  size of one element in bytes
** cmp           qsort()-style comparison function returning a negative,
**               zero, or positive value; must not be null (asserted)
**
** The result is 0 when the array was sorted or holds fewer than two
** elements (nothing to do), and -1 when element_size is zero or the
** one-element scratch buffer could not be allocated.
*/

int heap_sort( void *base, size_t element_cnt, size_t element_size, int ( *cmp )(const void *, const void * ) )
{
	static const char func[] = "heap_sort";
	size_t	ndx;
	char	*base_shift, *ptr;
	void 	*save;

	assert( base != (void *)0 );
	assert( cmp != (int (*)(const void *, const void * ) )0 );

	if ( Debug ) {
		/*
		** Every argument must match its conversion specifier: the
		** size_t values are cast to unsigned long for %lu, and the
		** function pointer is converted to void * for %p (a common
		** extension; %p is only defined for void * arguments).
		*/
		(void) fprintf( stderr, "%s( %p, %lu, %lu, %p )\n", func, base, (unsigned long)element_cnt, (unsigned long)element_size, (void *)cmp );
	}

	/*
	** Project macro from tbox.h.
	** NOTE(review): presumably records function entry for the debug
	** trace -- confirm against tbox.h.
	*/
	DEBUG_FUNC_START;

	if ( element_cnt < (size_t)2 ) {
		if ( Debug ) {
			(void) fputs( "Not enough elements to sort.\n", stderr );
		}
		RETURN_INT( 0 );
	}

	if ( element_size == (size_t)0 ) {
		if ( Debug ) {
			(void) fputs( "Element size was zero.\n", stderr );
		}
		RETURN_INT( -1 );
	}

	/*
	** Scratch space for the one element that is out of the array while
	** siftup() moves other elements around it.
	*/
	save = malloc( element_size );
	if ( save == (void *)0 ) {
		UNIX_ERROR( "malloc() failed" );
		RETURN_INT( -1 );
	}

	/*
	** Shift base address for indexing to start at one instead of zero,
	** so the children of element k are elements 2k and 2k+1.
	** NOTE(review): this forms a pointer one element before the array.
	** It is never dereferenced at index zero, but ISO C does not define
	** such a pointer; siftup() depends on this one-based origin.
	*/
	base_shift = ( char *)base - element_size;

	/*
	** Create heap.
	**
	** Start at ( element_cnt / (size_t)2 ) + (size_t)1 and walk the
	** index down to one; each pass lifts element ndx into save and lets
	** siftup() place it within the sub-heap rooted at ndx.
	*/
	ndx = ( element_cnt >> 1 ) + (size_t)1;
	ptr = base_shift + ( ndx * element_size );

	while ( ndx > (size_t)1 ) {
		--ndx;
		ptr -= element_size;
		(void) memcpy( save, (void *)ptr, element_size );
		siftup( base_shift, element_size, ndx, element_cnt, ptr, save, cmp );
	}

	/*
	** Sort phase.  Lift the last element into save and move the root
	** (the largest element) into the last slot.
	*/
	ptr = base_shift + ( element_cnt * element_size );
	(void) memcpy( save, (void *)ptr, element_size );
	(void) memcpy( ptr, (void *)base, element_size );

	/*
	** Each pass places save into the heap of the first ndx elements,
	** then lifts element ndx into save and stores the new root there.
	*/
	for ( ndx = element_cnt - (size_t)1; ndx > (size_t)1; --ndx ) {
		siftup( base_shift, element_size, (size_t)1, ndx, (char *)base, save, cmp );
		ptr -= element_size;
		(void) memcpy( save, (void *)ptr, element_size );
		(void) memcpy( (void *)ptr, base, element_size );
	}

	/*
	** Only the first slot remains; the saved element belongs there.
	*/
	(void) memcpy( base, save, element_size );

	free( save );

	RETURN_INT( 0 );
}

/*
** Place the element held in save into the heap rooted at index j, using
** one-based indexing from base, so that the sub-heap ending at index r
** is ordered with the largest element (per cmp) at its root.
**
** base          address that element index one is offset from by
**               element_size (elements live at base + k * element_size)
** element_size  size of one element in bytes
** j             index of the sub-heap root
** r             index of the last element that belongs to the heap
** previous      address of the vacant slot the descent starts from
** save          copy of the element to be placed
** cmp           comparison function returning <0, 0, or >0
**
** The classic formulation compares save against the larger child at
** every level on the way down (step H6) and stores it once it fits
** (step H8).  To reduce the number of comparisons this version instead
** walks the vacant slot all the way down along the larger children and
** then walks it back up until the parent is not smaller than save.
*/
void siftup( char *base, size_t element_size, size_t j, size_t r, char *previous, void *save, int ( *cmp )( const void *, const void * ) )
{
	const size_t	root = j;
	size_t	child, hole, parent;
	char	*child_ptr, *parent_ptr;

	/*
	** Descent: promote the larger child into the vacant slot at each
	** level until the slot has no child inside the heap.
	*/
	for ( child = root << 1; child <= r; child <<= 1 ) {
		child_ptr = base + ( child * element_size );

		/*
		** A right sibling exists only when child is not the last heap
		** element; prefer it when it compares greater.
		*/
		if ( child < r && ( *cmp )( (void *)( child_ptr + element_size ), (void *)child_ptr ) > 0 ) {
			++child;
			child_ptr += element_size;
		}

		(void) memcpy( (void *)previous, child_ptr, element_size );
		previous = child_ptr;
	}

	/*
	** Ascent: recover the index of the vacant slot from its address,
	** then pull parents back down while they compare less than save.
	** The walk never goes above the sub-heap root.
	*/
	assert( previous >= base );
	hole = (size_t)( previous - base ) / element_size;

	while ( hole != root ) {
		parent = hole >> 1;
		parent_ptr = base + ( parent * element_size );

		if ( ( *cmp )( (void *)parent_ptr, save ) >= 0 ) {
			break;
		}

		(void) memcpy( (void *)previous, parent_ptr, element_size );
		previous = parent_ptr;
		hole = parent;
	}

	/*
	** The vacant slot is now where the saved element belongs.
	*/
	(void) memcpy( (void *)previous, save, element_size );
}

#ifdef MT_heap_sort

#include <stdlib.h>
/*
** This function is used to regression test heap_sort().
** The following command is used to compile:
**   x=heap_sort; make "MT_CC=-DMT_$x" $x
*/

#define	SORT_CNT	10000

/*
** Comparison callback used by the module test: orders two long values.
**
** x, y   addresses of the two long values to compare
**
** Writes the trace line "compared" to stderr on every call, then returns
** -1 when *x is less than *y, 0 when they are equal, and 1 when *x is
** greater than *y.
*/
int long_cmp( const void *x, const void *y )
{
	const long	lhs = *(const long *)x;
	const long	rhs = *(const long *)y;

	(void) fprintf( stderr, "compared\n" );

	/*
	** Compare rather than subtract: a long difference can overflow, and
	** narrowing it to int can flip or lose its sign.
	*/
	return ( lhs > rhs ) - ( lhs < rhs );
}

/*
** Write count long values to stderr on a single line, each right
** aligned in a seven-character column, and finish the line.
*/
static void print_longs( const long *values, size_t count )
{
	size_t	i;

	for ( i = (size_t)0; i < count; ++i ) {
		(void) fprintf( stderr, "%7ld", values[ i ] );
	}
	(void) fputc( '\n', stderr );
}

/*
** Module test driver for heap_sort(): fills an array with SORT_CNT
** values taken from random() modulo 10000, prints it, sorts it, prints
** it again, and verifies that no element is greater than its successor.
** Returns 1 when heap_sort() reports -1 or the order check fails;
** otherwise exits with status 0.
*/
int main( void )
{
	static const char	func_name[] = "heap_sort";
	static const char	separator[] = ">>>\n";
	static const char	failure_fmt[] = ">>>\n>>> %s() was unsuccessful.\n";
	long	values[SORT_CNT];
	size_t	i;

	/*
	** Turn on the debug output of the function under test.
	*/
	Debug = 1;

	for ( i = (size_t)0; i < (size_t)SORT_CNT; ++i ) {
		values[ i ] = random() % 10000;
	}

	(void) fprintf( stderr, ">>> Start module test on function %s().\n", func_name );
	(void) fputs( separator, stderr );
	(void) fputs( ">>> TEST CASE #1\n", stderr );
	(void) fputs( ">>> Sort the following integers.\n", stderr );
	(void) fputs( separator, stderr );

	(void) fputs( "UNSORTED:\n", stderr );
	print_longs( values, (size_t)SORT_CNT );

	if ( heap_sort( (void *)values, SORT_CNT, sizeof( long ), long_cmp ) == -1 ) {
		(void) fprintf( stderr, failure_fmt, func_name );
		return 1;
	}

	(void) fputs( separator, stderr );
	(void) fputs( "SORTED:\n", stderr );
	print_longs( values, (size_t)SORT_CNT );

	/*
	** Every element must be no greater than the one that follows it.
	*/
	for ( i = (size_t)1; i < (size_t)SORT_CNT; ++i ) {
		if ( values[ i - 1 ] > values[ i ] ) {
			(void) fprintf( stderr, failure_fmt, func_name );
			return 1;
		}
	}

	(void) fprintf( stderr, ">>>\n>>> %s() was successful.\n", func_name );

	(void) fputs( separator, stderr );
	(void) fprintf( stderr, ">>> Module test on function %s() is complete.\n", func_name );
	exit( 0 );
}
#endif
