import functools
import os
import os.path as osp
import profile
import pstats


def profile_func(output_dir=None, filename=None, mode="w",
				 strip_dirs=False, sort_by=None):
	"""
	Decorator factory that profiles every call of the decorated function.

	Usage:
		@profile_func(output_dir="prof", sort_by="time")
		def work(): ...

	Each call runs the function under profile.Profile and writes a
	pstats-style report to the stats file. The file is opened with `mode`
	on every call, so with the default 'w' each call overwrites the report
	written by the previous call. If the file cannot be opened, an error
	message is printed and the function is called unprofiled.

	PARAMS:
	output_dir  -- directory the stats file is written to; it is created at
	               decoration time if missing. Default is the current dir.
	filename    -- name of stats file. Default is the name of the function
	               with postfix '_func.txt'.
	mode        -- open the file with this mode. Default is 'w' (writing).
	strip_dirs  -- if true strip dirnames in stats output. Default is false.
	sort_by     -- sort the stats by the given key. Default is no sorting.
	               Options:
	               - calls
	               - cumulative
	               - file
	               - line
	               - module
	               - nfl (name/file/lines)
	               - pcalls
	               - stdname
	               - time

	RETURNS:
	A decorator; the wrapped function returns whatever the original returns.
	"""

	def proffunc(f):
		out_filename = filename or "%s_func.txt" % f.__name__
		if output_dir:
			if not osp.exists(output_dir):
				os.makedirs(output_dir)
			out_path = osp.join(output_dir, out_filename)
		else:
			# current dir
			out_path = out_filename

		# functools.wraps copies __name__/__doc__/__module__ onto the wrapper
		# (the original `profiled_func.__name__ == f.__name__` was a no-op
		# comparison, not an assignment).
		@functools.wraps(f)
		def profiled_func(*args, **kwargs):
			try:
				out_file = open(out_path, mode)
			except IOError as err:
				# Best effort: report the problem and run without profiling.
				print("I/O error(%s): %s" % (err.errno, err.strerror))
				return f(*args, **kwargs)

			# Close the stats file even if f or the report generation raises.
			try:
				profiler = profile.Profile()
				retval = profiler.runcall(f, *args, **kwargs)

				stats = pstats.Stats(profiler)
				if strip_dirs:
					stats.strip_dirs()
				if sort_by:
					stats.sort_stats(sort_by)

				print_stats(stats, out_file)
			finally:
				out_file.close()

			return retval

		return profiled_func

	return proffunc


####################
# Helper functions
####################

def get_print_list(stats, out_file, sel_list):
	"""
	Select the profile entries to report and write the listing header.

	Same logic as pstats.Stats.get_print_list, but the header goes to
	out_file instead of the Stats object's own stream.

	PARAMS:
	stats     -- a pstats.Stats instance.
	out_file  -- writable file object that receives the header message.
	sel_list  -- sequence of restrictions, each passed in turn to
	             stats.eval_print_amount.

	RETURNS:
	(width, stat_list) -- width is the function-name column width plus 2
	(0 when no entries are selected); stat_list is the list of selected
	function keys. Nothing is written when no entries are selected.
	"""
	if stats.fcn_list:
		stat_list = stats.fcn_list[:]
		msg = "   Ordered by: " + stats.sort_type + '\n'
	else:
		# Materialize the keys so the result is a sliceable list on both
		# Python 2 and Python 3 (where keys() is a view).
		stat_list = list(stats.stats.keys())
		msg = "   Random listing order was used\n"

	for selection in sel_list:
		stat_list, msg = stats.eval_print_amount(selection, stat_list, msg)

	if not stat_list:
		return 0, stat_list

	# Message followed by a newline, like `print >> out_file, msg`.
	out_file.write(msg + "\n")

	if len(stat_list) < len(stats.stats):
		# Entries were filtered out: size the column for the survivors only.
		width = max(len(pstats.func_std_string(func)) for func in stat_list)
	else:
		width = stats.max_name_len
	return width + 2, stat_list

def print_line(stats, out_file, func):
	"""
	Write one row of the stats table for `func` to out_file.

	Columns, separated by single spaces: call count (right-justified to 9),
	tottime, tottime per call, cumtime, cumtime per call, and
	filename:lineno(function). The call count is shown as "nc/cc" when the
	two counts differ. A per-call column is left as 8 blanks when its
	divisor is zero. (Per pstats convention nc is presumably the total call
	count and cc the primitive call count.)

	TODO: should print percentages.
	"""
	cc, nc, tt, ct, _callers = stats.stats[func]
	blank = ' ' * 8

	ncalls = str(nc)
	if nc != cc:
		ncalls += '/' + str(cc)

	fields = [
		ncalls.rjust(9),
		pstats.f8(tt),
		pstats.f8(tt / nc) if nc != 0 else blank,
		pstats.f8(ct),
		pstats.f8(ct / cc) if cc != 0 else blank,
		pstats.func_std_string(func),
	]
	# One explicit write instead of softspace-chained Python 2 prints; the
	# resulting text (space-joined fields plus newline) is the same.
	out_file.write(' '.join(fields) + '\n')

def print_stats(stats, out_file, *amount):
	"""
	Write a full pstats-style report for `stats` to out_file.

	The report contains:
	- the source file names, if any;
	- the top-level function names;
	- a summary line (total calls, primitive calls when they differ, total
	  CPU seconds);
	- the table header and one row per selected function.

	PARAMS:
	stats     -- a pstats.Stats instance.
	out_file  -- writable file object receiving the report.
	amount    -- optional restrictions forwarded to get_print_list, as
	             accepted by pstats.Stats.eval_print_amount.
	"""
	for filename in stats.files:
		out_file.write("%s\n" % filename)
	if stats.files:
		# Blank separator line (previously written to an undefined name
		# `openfile`, which raised NameError).
		out_file.write("\n")

	indent = ' ' * 8
	for func in stats.top_level:
		out_file.write("%s %s\n" % (indent, pstats.func_get_function_name(func)))

	summary = "%s %s function calls" % (indent, stats.total_calls)
	if stats.total_calls != stats.prim_calls:
		summary += " (%d primitive calls)" % stats.prim_calls
	summary += " in %.3f CPU seconds\n" % stats.total_tt
	out_file.write(summary)
	out_file.write("\n")

	_, stat_list = get_print_list(stats, out_file, amount)
	if stat_list:
		out_file.write('   ncalls  tottime  percall  cumtime  percall'
					   ' filename:lineno(function)\n')
		for func in stat_list:
			print_line(stats, out_file, func)
		out_file.write("\n\n")