/* Copyright (C) 2009 Keith Crane

This file is part of DFILE Tools.

DFILE Tools is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or (at
your option) any later version.

DFILE Tools is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License along
with DFILE Tools; see the file COPYING.  If not, see
<http://www.gnu.org/licenses/>. */

#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <assert.h>
#include "tbox.h"

static const char       rcsid[] = "$Id: merge_sort.c,v 1.2 2009/10/16 18:00:43 keith Exp $";

/*
** $Log: merge_sort.c,v $
** Revision 1.2  2009/10/16 18:00:43  keith
** Added GPL to source code.
**
** Revision 1.1  2009/02/14 18:19:49  keith
** Initial revision
**
*/

/*
** Comparator most recently handed to init_sort(); desc_cmp() calls it with
** the sense reversed.  NOTE(review): this is file-scope mutable state, so
** overlapping merge_sort() calls (from several threads, or from a comparator
** that itself calls merge_sort()) would overwrite it.
*/
static int ( *cmp_func )( const void *, const void * );
static int desc_cmp( const void *, const void * );
static void init_sort( char *, size_t, size_t, int (*)( const void *, const void * ), char *, size_t );
/* Unlike the helpers above, internal_merge_sort() has external linkage. */
void internal_merge_sort( char *, size_t, size_t, int ( * )( const void *, const void * ), char *, long );

/*
** This function sorts an array in place using merge sort.
**
**   base          - array to sort; must not be NULL
**   element_cnt   - number of elements in base
**   element_size  - size of one element in bytes; must be non-zero
**   cmp           - comparison function; a result > 0 means the first
**                   argument sorts after the second; must not be NULL
**
** Partitions of 2^skip_init_passes elements are insertion-sorted first
** (init_sort()), then merged (internal_merge_sort()) through a temporary
** work buffer of element_cnt * element_size bytes.
**
** Returns 0 on success (arrays with fewer than two elements are left
** untouched) and -1 when element_size is zero, the array's byte size does
** not fit in a size_t, or the work buffer cannot be allocated.
*/

int merge_sort( void *base, size_t element_cnt, size_t element_size, int ( *cmp )(const void *, const void * ) )
{
	static const char func[] = "merge_sort";
	void 	*work;
	const long	skip_init_passes = 4; /* needs to be an even number */

	assert( base != (void *)0 );
	assert( cmp != (int (*)(const void *, const void * ) )0 );

	if ( Debug ) {
		/*
		** size_t has no C89 conversion specifier, so print the counts
		** through unsigned long.  NOTE(review): %p formally expects a
		** void *; printing the function pointer cmp relies on platform
		** behavior.
		*/
		(void) fprintf( stderr, "%s( %p, %lu, %lu, %p )\n", func, base, (unsigned long)element_cnt, (unsigned long)element_size, cmp );
	}

	DEBUG_FUNC_START;

	if ( element_cnt < (size_t)2 ) {
		if ( Debug ) {
			(void) fputs( "Not enough elements to sort.\n", stderr );
		}
		RETURN_INT( 0 );
	}

	if ( element_size == (size_t)0 ) {
		if ( Debug ) {
			(void) fputs( "Element size was zero.\n", stderr );
		}
		RETURN_INT( -1 );
	}

	/*
	** Refuse arrays whose byte size would wrap around size_t; otherwise
	** malloc() would return a buffer too small for the merge passes.
	*/
	if ( element_cnt > (size_t)-1 / element_size ) {
		if ( Debug ) {
			(void) fputs( "Array byte size overflows size_t.\n", stderr );
		}
		RETURN_INT( -1 );
	}

	work = malloc( element_size * element_cnt );
	if ( work == (void *)0 ) {
		UNIX_ERROR( "malloc() failed" );
		RETURN_INT( -1 );
	}

	/* Sort small partitions in the layout the merge passes expect. */
	init_sort( (char *)base, element_cnt, element_size, cmp, (char *)work, skip_init_passes );

	internal_merge_sort( (char *)base, element_cnt, element_size, cmp, (char *)work, skip_init_passes );

	free( work );

	RETURN_INT( 0 );
}

/*
** Merge phase of merge_sort(): repeatedly merges pairs of sorted runs,
** doubling the run length each pass, until one run covers the array.
**
** The S1..S13 step labels appear to follow Knuth's straight two-way merge
** sort (TAOCP Vol. 3, Sec. 5.2.4, Algorithm S) -- NOTE(review): confirm.
** Each pass reads runs from both ends of the source area (lhs_ptr moving up
** from the start, rhs_ptr moving down from the end) and writes the merged
** runs alternately to the front (ascending addresses) and to the back
** (descending addresses) of the destination area.
**
**   base             - element_cnt elements of element_size bytes; holds
**                      the sorted result on return
**   cmp              - comparison function; a result > 0 means the first
**                      argument sorts after the second
**   work             - scratch area of element_cnt * element_size bytes
**   skip_init_passes - the first pass merges runs of 2^skip_init_passes
**                      elements.  The caller must already have sorted those
**                      runs so that each is in ascending order in the
**                      direction it is scanned: by increasing address for
**                      runs read through lhs_ptr, by decreasing address for
**                      runs read through rhs_ptr (init_sort() lays base out
**                      this way).
*/
void internal_merge_sort( char *base, size_t element_cnt, size_t element_size, int ( *cmp )( const void *, const void * ), char *work, long skip_init_passes )
{
	char	*swap_output_ptr;
	char	*lhs_ptr, *rhs_ptr, *output_ptr, *save_output_ptr;
	char	*init_base_rhs_ptr, *init_work_rhs_ptr;
	char	*init_base_output_ptr, *init_work_output_ptr;
	char	*init_base_save_output_ptr, *init_work_save_output_ptr;
	size_t	bytes_to_be_sorted;

	int	direction_flag;	/* 0: read base, write work; 1: read work, write base */

	long	output_ptr_increment; /* values are { +/- element_size } */

	long	merge_size, lhs_merge_cnt, rhs_merge_cnt;

	bytes_to_be_sorted = element_cnt * element_size;
	/* Last element of each area: where the downward (rhs) scan begins. */
	init_base_rhs_ptr = base + ( bytes_to_be_sorted - element_size );
	init_work_rhs_ptr = work + ( bytes_to_be_sorted - element_size );
	/*
	** output_ptr is advanced before every store, so the front-end writer
	** starts one element before the area.  NOTE(review): forming a pointer
	** before the start of an array is undefined in ISO C even though it is
	** never dereferenced here.
	*/
	init_base_output_ptr = base - element_size;
	init_work_output_ptr = work - element_size;
	/*
	** One past the end: the stop marker for the first run of a pass, and
	** (after the swap in S12) the starting point of the back-end writer,
	** which pre-decrements.
	*/
	init_base_save_output_ptr = base + bytes_to_be_sorted;
	init_work_save_output_ptr = work + bytes_to_be_sorted;


/* S1: */
	direction_flag = 0;

	merge_size = 1 << skip_init_passes;

	if ( merge_size >= element_cnt ) {
		/*
		** Small array already sorted.  A single initial run covers the
		** whole array, and the caller sorted those runs beforehand.
		*/
		return;
	}

	/*
	** Start a pass: choose source/destination areas by direction_flag and
	** begin writing at the front of the destination.
	*/
S2:
	if ( direction_flag == 0 ) {
		lhs_ptr = base;
		rhs_ptr = init_base_rhs_ptr;
		output_ptr = init_work_output_ptr;
		save_output_ptr = init_work_save_output_ptr;
	} else {
		assert( direction_flag == 1 );
		lhs_ptr = work;
		rhs_ptr = init_work_rhs_ptr;
		output_ptr = init_base_output_ptr;
		save_output_ptr = init_base_save_output_ptr;
	}

	output_ptr_increment = element_size;
	lhs_merge_cnt = merge_size;
	rhs_merge_cnt = merge_size;

	/* Compare run heads; on a tie the lhs element is taken (S4). */
S3:
	if ( ( *cmp )( (void *)lhs_ptr, (void *)rhs_ptr ) > 0 ) {
		goto S8;
	}

/* S4: */
	output_ptr += output_ptr_increment;
	(void) memcpy( (void *)output_ptr, (void *)lhs_ptr, element_size );

/* S5: */
	lhs_ptr += element_size;
	--lhs_merge_cnt;

	if ( lhs_merge_cnt > 0L ) {
		goto S3;
	}

	/*
	** lhs run exhausted: copy the rest of the rhs run.  Reaching
	** save_output_ptr means every element of this pass has been written.
	*/
S6:
	output_ptr += output_ptr_increment;
	if ( output_ptr == save_output_ptr ) {
		goto S13;
	}

	(void) memcpy( (void *)output_ptr, (void *)rhs_ptr, element_size );

/* S7: */
	rhs_ptr -= element_size;
	--rhs_merge_cnt;
	if ( rhs_merge_cnt > 0L ) {
		goto S6;
	}

	goto S12;

	/* The rhs element sorts first: transmit it. */
S8:
	output_ptr += output_ptr_increment;
	(void) memcpy( (void *)output_ptr, (void *)rhs_ptr, element_size );

/* S9: */
	rhs_ptr -= element_size;
	--rhs_merge_cnt;
	if ( rhs_merge_cnt > 0L ) {
		goto S3;
	}

	/*
	** rhs run exhausted: copy the rest of the lhs run.  Reaching
	** save_output_ptr means every element of this pass has been written.
	*/
S10:
	output_ptr += output_ptr_increment;
	if ( output_ptr == save_output_ptr ) {
		goto S13;
	}

	(void) memcpy( (void *)output_ptr, (void *)lhs_ptr, element_size );

/* S11: */
	lhs_ptr += element_size;
	--lhs_merge_cnt;
	if ( lhs_merge_cnt > 0L ) {
		goto S10;
	}

	/*
	** One merged run is complete.  Reset the run counters, reverse the
	** output direction and swap output_ptr with save_output_ptr, so the next
	** run is written from the other end of the destination and this end's
	** position becomes the stop marker.
	*/
S12:
	lhs_merge_cnt = merge_size;
	rhs_merge_cnt = merge_size;
	output_ptr_increment = -output_ptr_increment;
	swap_output_ptr = output_ptr;
	output_ptr = save_output_ptr;
	save_output_ptr = swap_output_ptr;
	/*
	** At most merge_size elements left: they form one already-sorted run,
	** so just copy them (S10).  lhs_ptr > rhs_ptr is tested first because
	** the pointer difference would then be negative, and the product on
	** the right is typically unsigned (size_t arithmetic).
	*/
	if ( lhs_ptr > rhs_ptr || rhs_ptr - lhs_ptr < merge_size * element_size ) {
		goto S10;
	}

	goto S3;

	/*
	** Pass complete: double the run length.  If runs are still shorter than
	** the array, swap source and destination and run another pass.
	*/
S13:
	merge_size <<= 1;
	if ( merge_size < element_cnt ) {
		direction_flag = 1 - direction_flag;
		goto S2;
	}

	/* The final pass wrote into work; copy the result back to base. */
	if ( direction_flag == 0 ) {
		(void) memcpy( (void *)base, work, bytes_to_be_sorted );
	}
}

/*
** This function does an initial insertion sort on small merge partitions,
** laying the runs out the way internal_merge_sort() scans them.
**
** The array is treated as init_partition_cnt partitions of 2^passes
** elements plus a remainder of fewer than 2^passes elements:
**
**   even partition count:  [asc]...[asc] [remainder asc]        [desc]...[desc]
**   odd partition count:   [asc]...[asc] [asc] [remainder desc] [desc]...[desc]
**
** Each "[..]...[..]" group holds init_partition_cnt / 2 partitions.  "asc"
** partitions are sorted with cmp.  "desc" partitions are sorted with
** desc_cmp() so that they are in cmp order when scanned from the high end of
** the array downward.
**
**   base          - array of element_cnt elements of element_size bytes
**   cmp           - user comparison function; also saved in cmp_func for
**                   desc_cmp()
**   save          - scratch buffer passed through to internal_insertion_sort()
**   passes        - log2 of the partition size
*/
static void init_sort( char *base, size_t element_cnt, size_t element_size, int ( *cmp )( const void *, const void * ), char *save, size_t passes )
{
	size_t	init_partition_size, ndx, bytes_per_partition;
	size_t	init_partition_cnt, partitions_per_side, middle_element_cnt;
	char	*ptr;

	/*
	** Shift a size_t rather than an int: "1 << passes" is undefined once
	** passes reaches the width of int.
	*/
	init_partition_size = (size_t)1 << passes;
	init_partition_cnt = element_cnt / init_partition_size;
	partitions_per_side = init_partition_cnt >> 1; /* divide by 2 */

	/* Elements left over after the full partitions. */
	middle_element_cnt = element_cnt - ( init_partition_cnt * init_partition_size );

	/* desc_cmp() reads the user comparator from this file-scope pointer. */
	cmp_func = cmp;

	bytes_per_partition = init_partition_size * element_size;

	/* Low side: ascending partitions. */
	ptr = base;
	for ( ndx = (size_t)0; ndx < partitions_per_side; ++ndx ) {
		internal_insertion_sort( ptr, init_partition_size, element_size, cmp, save );
		ptr += bytes_per_partition;
	}

	if ( init_partition_cnt & (size_t)1 ) {
		/*
		** Odd number of partitions: the extra full partition joins the low
		** side and the remainder is read downward, so it is sorted
		** descending.
		*/
		internal_insertion_sort( ptr, init_partition_size, element_size, cmp, save );
		ptr += bytes_per_partition;
		internal_insertion_sort( ptr, middle_element_cnt, element_size, desc_cmp, save );
		ptr += middle_element_cnt * element_size;
	} else {
		/*
		** Even number of partitions: the remainder is read upward, so it is
		** sorted ascending.
		*/
		internal_insertion_sort( ptr, middle_element_cnt, element_size, cmp, save );
		ptr += middle_element_cnt * element_size;
	}

	/* High side: descending partitions. */
	for ( ndx = (size_t)0; ndx < partitions_per_side; ++ndx ) {
		internal_insertion_sort( ptr, init_partition_size, element_size, desc_cmp, save );
		ptr += bytes_per_partition;
	}
}

/*
** Reverses the order defined by the comparator that init_sort() stored in
** cmp_func.  Returns -1 when cmp_func() reports x > y, 1 when it reports
** x < y, and 0 when they compare equal.
**
** The result is normalized instead of negated because -INT_MIN overflows
** (undefined behavior) when a comparator happens to return INT_MIN.
*/
static int desc_cmp( const void *x, const void *y )
{
	const int	order = ( *cmp_func )( x, y );

	if ( order > 0 ) {
		return -1;
	}

	return ( order < 0 ) ? 1 : 0;
}

#ifdef MT_merge_sort

#include <stdlib.h>
/*
** This function is used to regression test merge_sort().
** The following command is used to compile:
**   x=merge_sort; make "MT_CC=-DMT_$x" $x
*/

#define	SORT_CNT	1000000

/*
** qsort()-style comparator for long values: returns a negative value, zero
** or a positive value as *x is less than, equal to or greater than *y.
**
** Compares instead of subtracting: *x - *y can overflow a long, and the
** difference is truncated when it is returned as an int, so it can end up
** with the wrong sign.
*/
int long_cmp( const void *x, const void *y )
{
	const long	lhs = *(const long *)x;
	const long	rhs = *(const long *)y;

	return ( lhs > rhs ) - ( lhs < rhs );
}

/*
** Fills data[0..cnt) with pseudo-random values in [0, modulus), sorts it
** with merge_sort() and a copy with qsort(), and returns 1 only if
** merge_sort() succeeded and both results match element for element.
*/
static int sort_and_check( long *data, long *expected, unsigned long cnt, long modulus )
{
	unsigned long	ndx;

	for ( ndx = 0UL; ndx < cnt; ++ndx ) {
		data[ ndx ] = random() % modulus;
		expected[ ndx ] = data[ ndx ];
	}

	/* Reference result from the C library. */
	qsort( (void *)expected, (size_t)cnt, sizeof( long ), long_cmp );

	if ( merge_sort( (void *)data, (size_t)cnt, sizeof( long ), long_cmp ) == -1 ) {
		return 0;
	}

	/* Equal longs are indistinguishable, so both results must be identical. */
	for ( ndx = 0UL; ndx < cnt; ++ndx ) {
		if ( data[ ndx ] != expected[ ndx ] ) {
			return 0;
		}
	}

	return 1;
}

int main( void )
{
	static const char	complete_msg[] =  ">>> Module test on function %s() is complete.\n";
	static const char	test_func[] = "merge_sort";
	static const char	successful[] = ">>>\n>>> %s() was successful.\n";
	static const char	unsuccessful[] = ">>>\n>>> %s() was unsuccessful.\n";
	static const char	blank_line[] = ">>>\n";
	/*
	** static: two arrays of SORT_CNT longs are several megabytes, too large
	** for a typical default stack as automatic variables.
	*/
	static long	x[SORT_CNT];
	static long	expected[SORT_CNT];
	unsigned long	cnt;

	Debug = 0;

	(void) fprintf( stderr, ">>> Start module test on function %s().\n", test_func );
	(void) fputs( blank_line, stderr );
	(void) fputs( ">>> TEST CASE #1\n", stderr );
	(void) fputs( ">>> Sort every array length from 0 to 200.\n", stderr );

	/*
	** Short lengths cover empty and single-element arrays, arrays that fit
	** in one initial partition, and both odd and even partition counts
	** with every possible remainder.
	*/
	for ( cnt = 0UL; cnt <= 200UL; ++cnt ) {
		if ( !sort_and_check( x, expected, cnt, 100L ) ) {
			(void) fprintf( stderr, ">>> Length %lu was not sorted correctly.\n", cnt );
			(void) fprintf( stderr, unsuccessful, test_func );
			return 1;
		}
	}

	(void) fputs( blank_line, stderr );
	(void) fputs( ">>> TEST CASE #2\n", stderr );
	(void) fprintf( stderr, ">>> Sort %lu random integers and compare with qsort().\n", (unsigned long)SORT_CNT );

	if ( !sort_and_check( x, expected, (unsigned long)SORT_CNT, (long)SORT_CNT ) ) {
		(void) fprintf( stderr, unsuccessful, test_func );
		return 1;
	}

	(void) fprintf( stderr, successful, test_func );

	(void) fputs( blank_line, stderr );
	(void) fprintf( stderr, complete_msg, test_func );
	exit( 0 );
}
#endif
