#!/usr/bin/env python
'''
qcd_from_data <path to config.json> 

Configuration parameters

data_file: 
	path to data file, used to get the templates 

mc_file:
	path to MC file, used for the normalisation of data histograms

subtract_other_samples:
	dictionary of (sample, file) to be removed from the template  control region (data_file)
	
histogram_path:
	source of histograms. All sub-paths will be considered (recursive)
	
ignore_subpaths:
	sub-paths of histogram_path to be ignored
	
normalisation_keyword_in_path:
	part of the path that determines the normalisation (usually signal) region
	
shape_keyword_in_path:
	part of the path that determines the template (usually control/sideband) region
	
shape_btag:
	b-tag multiplicity to be used for the template region

shape_btag_for_exceptions:
	b-tag multiplicity to be used for the template region for exceptions (see shape_btag_exceptions)
	
shape_btag_exceptions:
	list of histogram names that have a different b-tag multiplicity for the template region

remove_for_shape:
	part of the histogram name to be removed from the template histogram name.
	Useful for re-weighted histograms (mc-only).

output_file:
	path to the output file
	
Uses the data_file to extract the templates, removes other samples 
(subtract_other_samples) and normalises it according to mc_file.
	
'''
from ROOT import gROOT
gcd = gROOT.cd
from optparse import OptionParser
from dps.utils.file_utilities import read_data_from_JSON
# @BROKEN
from dps.utils.ROOT_utililities import root_mkdir, find_btag, get_histogram_dictionary
from dps.utils.hist_utilities import clean_control_region
from rootpy.io import root_open

def main():
    '''
    Script entry point.

    Prints a short banner, obtains the configuration sets from
    parse_options() and calls create_qcd_file() once per configuration.
    When the --test option is given, the configurations from the command
    line are replaced by a single set of test values labelled 'test.json'.
    '''
    # Single-argument print(...) calls produce identical output under
    # Python 2 (parenthesised expression) and Python 3 (function call).
    print("Welcome to the QCD-from-data merging script")
    print('Please take a seat while the code is being developed.')
    print('Once finished you will be able to create a single file using shapes from data and normalisation from MC')
    print('In the meantime have a look at the script usage')
    print('')

    options, input_values_sets, json_input_files = parse_options()
    if options.test:
        # NOTE(review): setup_test_values is not defined in the visible part
        # of this file - confirm it exists before relying on --test.
        input_values_sets = [setup_test_values()]
        json_input_files = ['test.json']

    for input_values, json_file in zip(input_values_sets, json_input_files):
        print('Processing %s' % (json_file,))
        create_qcd_file(input_values)
		
def parse_options():
    '''
    Parse the command line.

    The module docstring is used as the usage text. Every positional
    argument is treated as the path to a JSON configuration file and is
    loaded with read_data_from_JSON, unless --test is given, in which case
    no file is read and both returned lists are empty.

    Returns a tuple (options, input_values_sets, json_input_files) where
    the two lists are aligned: entry i of input_values_sets was read from
    entry i of json_input_files.
    '''
    parser = OptionParser(__doc__)
    parser.add_option(
        '-t', '--test',
        dest='test',
        action='store_true',
        help='Run with test values and write them to test.json',
    )
    options, args = parser.parse_args()

    # In test mode the positional arguments are ignored entirely.
    json_input_files = [] if options.test else list(args)
    input_values_sets = [
        read_data_from_JSON(json_file) for json_file in json_input_files
    ]

    return options, input_values_sets, json_input_files
   
def create_qcd_file(input_values):
    '''
    Create one output file with shapes taken from data and normalisation
    taken from MC.

    input_values is a dictionary with the keys described in the module
    docstring (data_file, mc_file, histogram_path, shape_keyword_in_path,
    shape_btag, shape_btag_for_exceptions, shape_btag_exceptions,
    remove_for_shape, normalisation_keyword_in_path, ignore_subpaths,
    subtract_other_samples, output_file). A missing key raises KeyError.

    For every histogram found in mc_file below a path containing
    histogram_path (and none of ignore_subpaths):
      1. its integral (requested with overflow = True) is the normalisation,
      2. the matching template is looked up in data_file by swapping the
         normalisation keyword in the path for the shape keyword, removing
         the remove_for_shape fragments from the name and swapping the
         b-tag found by find_btag for the configured shape b-tag,
      3. the template is passed through clean_control_region together with
         the histograms returned by get_histogram_dictionary for
         subtract_other_samples,
      4. the result is scaled to the normalisation and stored under the
         original MC path and name.
    All stored histograms are then written to output_file (recreated).
    '''
    data_file = input_values['data_file']
    mc_file = input_values['mc_file']
    histogram_path = input_values['histogram_path']
    shape_keyword_in_path = input_values['shape_keyword_in_path']
    shape_btag = input_values['shape_btag']
    shape_btag_for_exceptions = input_values['shape_btag_for_exceptions']
    shape_btag_exceptions = input_values['shape_btag_exceptions']
    remove_for_shape = input_values['remove_for_shape']
    normalisation_keyword_in_path = input_values['normalisation_keyword_in_path']
    ignore_subpaths = input_values['ignore_subpaths']
    subtract_other_samples = input_values['subtract_other_samples']
    output_file = input_values['output_file']

    total_histograms = 0
    output = {}
    data_file_handle = root_open(data_file)
    # try/finally so the data file is closed even if a lookup below fails.
    try:
        get_shape_hist = data_file_handle.Get
        with root_open(mc_file) as f:
            for path, _dirs, histograms in f.walk():
                if histogram_path not in path or not histograms:
                    continue
                if any(subpath in path for subpath in ignore_subpaths):
                    continue
                for histogram in histograms:
                    hist = f.Get(path + '/' + histogram)
                    normalisation = hist.integral(overflow = True)
                    shape_path = path.replace(normalisation_keyword_in_path,
                                              shape_keyword_in_path)

                    # build the template name: strip the configured
                    # fragments, then swap the b-tag
                    current_btag, _ = find_btag(histogram)
                    is_exception = any(var in histogram
                                       for var in shape_btag_exceptions)
                    shape_histogram = histogram
                    for fragment in remove_for_shape:
                        shape_histogram = shape_histogram.replace(fragment, '')
                    if is_exception:
                        new_btag = shape_btag_for_exceptions
                    else:
                        new_btag = shape_btag
                    shape_histogram = shape_histogram.replace(current_btag,
                                                              new_btag)
                    full_shape_path = shape_path + '/' + shape_histogram

                    # switch to the gROOT directory before cloning
                    # (presumably so the clone is not tied to an input
                    # file - TODO confirm)
                    gcd()
                    output_hist = get_shape_hist(full_shape_path).clone()
                    other_samples = get_histogram_dictionary(
                        full_shape_path, subtract_other_samples)
                    # Snapshot the sample names *before* 'data' is added.
                    # list() keeps this correct where keys() is a live view.
                    subtract_samples = list(other_samples.keys())
                    other_samples['data'] = output_hist
                    output_hist = clean_control_region(
                        other_samples, subtract = subtract_samples)

                    # scale the histogram
                    n_entries_shape = output_hist.integral(overflow = True)
                    scale_factor = 1
                    if n_entries_shape > 0:
                        if normalisation == 0:
                            # no MC normalisation available: scale the
                            # template to unit integral instead of zero
                            scale_factor = 1.0 / n_entries_shape
                        else:
                            scale_factor = float(normalisation) / n_entries_shape

                    output_hist.Scale(scale_factor)
                    output[path + '/' + histogram] = output_hist
                total_histograms += len(histograms)
    finally:
        data_file_handle.close()

    output_file_handle = root_open(output_file, 'recreate')
    # probably faster to use TFileCache within the loop above.
    try:
        for path_with_hist, histogram in output.items():
            # Split on the last '/' only. Removing '/<name>' with
            # str.replace would also mangle directory components that
            # start with the histogram name.
            path, histogram_name = path_with_hist.rsplit('/', 1)
            root_mkdir(output_file_handle, path)
            output_file_handle.cd(path)
            histogram.write(histogram_name)
            output_file_handle.cd()
    finally:
        output_file_handle.close()
    print('Processed %s histograms' % (total_histograms,))

# Run the merge only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
	