#!/usr/bin/env python3

import math

import basf2
import fei
import modularAnalysis as ma
from variables import variables
import variables.collections as vc
import variables.utils as vu
import vertex
from stdCharged import stdE, stdMu
import stdPi0s
import stdV0s

# Conditions and reproducibility: analysis global tag first, then a fixed random seed.
basf2.conditions.prepend_globaltag(ma.getAnalysisGlobaltag())
basf2.set_random_seed("BelleIISummerSchool2022")

# The single steering path that every module below is appended to.
main = basf2.create_path()

inputMdstFile = '/home/belle2/fmeier/BelleIISummerSchool2022/advanced.mdst.root'
ma.inputMdst(inputMdstFile, path=main)

# --- Tag side: hadronic Full Event Interpretation ---------------------------
# Kinematic requirement on the tag B; passed to the FEI and re-applied below.
cut_on_Btag = 'abs(deltaE)<0.18 and Mbc > 5.27'
feiChannels = fei.get_default_channels(hadronic=True, semileptonic=False, baryonic=True,
                                       B_extra_cut=cut_on_Btag)
feiConfig = fei.config.FeiConfiguration(prefix='FEIv4_2021_MC14_release_05_01_12',
                                        training=False, monitor=False, cache=0)
main.add_path(fei.get_path(feiChannels, feiConfig).path)

# Rank all tag candidates by SignalProbability (stored as extraInfo 'sigProb_rank';
# numBest=0 keeps every candidate), then apply a loose probability cut.
ma.rankByHighest('B+:generic', 'extraInfo(SignalProbability)', outputVariable='sigProb_rank', numBest=0, path=main)
tagSelection = 'extraInfo(SignalProbability) > 0.001 and ' + cut_on_Btag
ma.applyCuts('B+:generic', tagSelection, path=main)

# Only events with at least one surviving tag B continue down the path.
ma.applyEventCuts('nParticlesInList(B+:generic) >= 1', path=main)

# --- Final-state particles ----------------------------------------------------
# Loose impact-parameter requirement shared by the charged-track lists.
looseIPCut = 'dr < 0.5 and abs(dz) < 2'
# Leptons additionally need p > 0.3.
LeptonSkimCut = f'{looseIPCut} and p > 0.3'

# Muon candidates: loose track selection inside the KLM acceptance. No PID
# requirement is applied; the commented stdMu call below would add one.
ma.fillParticleList('mu+:loose', f'{LeptonSkimCut} and inKLMAcceptance', path=main)
# stdMu('FixedThresh05', method='bdt', classification='global', inputListName='mu+:loose', outputListLabel='sel', lid_weights_gt='leptonid_Moriond2022_Official_rel5_v1a', release=5, path=main)

# Electron candidates: loose track selection inside the CDC acceptance, before
# bremsstrahlung recovery.
ma.fillParticleList('e+:uncorrected', f'{LeptonSkimCut} and inCDCAcceptance', path=main)

# Pre-selection: require at least one lepton candidate in the event. Every signal
# B candidate needs a lepton, so this does not change the output; it only skips
# the expensive combinatorics below for lepton-less events.
ma.applyEventCuts('nParticlesInList(mu+:loose) > 0 or nParticlesInList(e+:uncorrected) > 0', path=main)

# Photons: cluster-energy threshold depends on the ECL region
# (clusterReg 1 = forward, 2 = barrel, 3 = backward).
variables.addAlias("goodFWDGamma", "passesCut(clusterReg == 1 and clusterE > 0.075)")
variables.addAlias("goodBRLGamma", "passesCut(clusterReg == 2 and clusterE > 0.05)")
variables.addAlias("goodBWDGamma", "passesCut(clusterReg == 3 and clusterE > 0.1)")
variables.addAlias('goodGamma', 'passesCut(goodFWDGamma or goodBRLGamma or goodBWDGamma)')
ma.fillParticleList('gamma:good', 'goodGamma', path=main)

# Belle-style bremsstrahlung recovery with photons within 5 degrees of the electron
# (multiplePhotons=False: a single photon per electron).
fivedegrees = math.radians(5)
ma.correctBremsBelle('e+:corrected', 'e+:uncorrected', 'gamma:good', multiplePhotons=False, angleThreshold=fivedegrees, path=main)
# stdE('FixedThresh09', method='likelihood', classification='binary', inputListName="e+:corrected", outputListLabel='sel', lid_weights_gt='leptonid_Moriond2022_Official_rel5_v1a', release=5, path=main)

# Hadrons for the D0 and eta reconstruction.
ma.fillParticleList('K+:sel', f'{looseIPCut} and kaonID > 0.6', path=main)
ma.fillParticleList('pi+:sel', f'{looseIPCut} and pionID > 0.5', path=main)

stdPi0s.stdPi0s('eff50_May2020', path=main)
stdV0s.stdKshorts(path=main)

# reconstruct D0 in many different modes.
# Each entry is (list label, daughter lists, |dM| window); an entry's position in
# the table is its decayModeID. Modes containing a pi0 use the wider 0.025 window.
d0Modes = [
    ('Kpi', 'K-:sel pi+:sel', 0.015),
    ('K2pi', 'K-:sel pi+:sel pi0:eff50_May2020', 0.025),
    ('K3pi', 'K-:sel pi+:sel pi-:sel pi+:sel', 0.015),
    ('KS2pi', 'K_S0:merged pi+:sel pi-:sel', 0.015),
    ('KK', 'K-:sel K+:sel', 0.015),
    ('KSpi', 'K_S0:merged pi0:eff50_May2020', 0.025),
    ('KS3pi', 'K_S0:merged pi+:sel pi-:sel pi0:eff50_May2020', 0.025),
    ('pipi', 'pi+:sel pi-:sel', 0.015),
    ('K4pi', 'K-:sel pi+:sel pi+:sel pi-:sel pi0:eff50_May2020', 0.025),
    ('3pi', 'pi+:sel pi-:sel pi0:eff50_May2020', 0.025),
]
for d0ModeID, (d0Label, d0Daughters, d0Window) in enumerate(d0Modes):
    ma.reconstructDecay(f'D0:{d0Label} -> {d0Daughters}', f'abs(dM) < {d0Window}', dmID=d0ModeID, path=main)

# Merge the modes from the same table so that no reconstructed mode
# (such as D0:KSpi, K_S0 pi0) can be left out of D0:all by accident.
ma.copyLists('D0:all', [f'D0:{label}' for label, _, _ in d0Modes], path=main)

# Store the D0 masses before any fit as extraInfo (D_BFInvM, D_BFM).
ma.variablesToExtraInfo('D0:all', {'InvM': 'D_BFInvM', 'M': 'D_BFM'}, path=main)

# --- eta candidates ------------------------------------------------------------
# Both modes store their pre-fit masses as extraInfo (eta_BFInvM, eta_BFM)
# before the mass-constrained fit; conf_level=0 drops candidates whose fit fails.
etaBeforeFitInfo = {'InvM': 'eta_BFInvM', 'M': 'eta_BFM'}

# eta -> gamma gamma (decayModeID 0), mass-constrained KFit.
ma.reconstructDecay('eta:gg -> gamma:good gamma:good', 'abs(dM) < 0.07', dmID=0, path=main)
ma.variablesToExtraInfo('eta:gg', etaBeforeFitInfo, path=main)
vertex.kFit('eta:gg', conf_level=0, fit_type='mass', path=main)

# eta -> pi+ pi- pi0 (decayModeID 1), TreeFitter with eta and pi0 mass constraints.
ma.reconstructDecay('eta:pipipiz -> pi+:sel pi-:sel pi0:eff50_May2020', 'abs(dM) < 0.07', dmID=1, path=main)
ma.variablesToExtraInfo('eta:pipipiz', etaBeforeFitInfo, path=main)
vertex.treeFit('eta:pipipiz', conf_level=0, massConstraint=['eta', 'pi0'], updateAllDaughters=True, path=main)

# Single eta list used by the signal-side B reconstruction.
ma.copyLists('eta:all', ['eta:gg', 'eta:pipipiz'], path=main)

# --- Signal side: B+ -> anti-D0 eta l+ nu (the neutrino is not reconstructed) ---
# decayModeID 0 = muon mode, 1 = electron mode.
signalModes = [('B+:Detamu', 'mu+:loose'), ('B+:Detae', 'e+:corrected')]
for signalModeID, (signalList, leptonList) in enumerate(signalModes):
    ma.reconstructDecay(f'{signalList} -> anti-D0:all eta:all {leptonList} ?nu', '', dmID=signalModeID, path=main)
ma.copyLists('B+:eta', [name for name, _ in signalModes], path=main)

# Vertex fit of the signal B with mass constraints on D0, pi0, K_S0 and eta;
# conf_level=0 removes candidates whose fit does not converge.
vertex.treeFit('B+:eta', conf_level=0, massConstraint=['D0', 'pi0', 'K_S0', 'eta'], updateAllDaughters=False, path=main)

# Full event: tag B (daughter 0) + signal B (daughter 1).
ma.reconstructDecay('Upsilon -> B-:generic B+:eta', '', path=main)

# Particle <-> MCParticle relations for the full candidate.
ma.matchMCTruth('Upsilon', path=main)

# Rest of event: keep only Upsilon candidates without any additional track
# passing the loose IP cut (photons in the mask use the goodGamma selection).
ma.buildRestOfEvent('Upsilon', path=main)
ma.appendROEMask('Upsilon', mask_name='ExtraMask', trackSelection=looseIPCut, eclClusterSelection='goodGamma', path=main)
ma.applyCuts('Upsilon', 'nROE_Charged(ExtraMask) == 0', path=main)

# --- Event shape -----------------------------------------------------------------
# Dedicated loose track/photon lists; only the needed event-shape quantities are
# enabled (collision axis and jets are switched off).
ma.fillParticleList('gamma:evtshape', 'E > 0.05 and thetaInCDCAcceptance', path=main)
ma.fillParticleList('pi+:evtshape', 'pt > 0.1 and thetaInCDCAcceptance and abs(dz) < 3.0 and dr < 0.5', path=main)
eventShapeLists = ['pi+:evtshape', 'gamma:evtshape']
ma.buildEventShape(
    inputListNames=eventShapeLists,
    cleoCones=True,
    collisionAxis=False,
    harmonicMoments=True,
    jets=False,
    sphericity=True,
    thrust=True,
    path=main,
)

# --- Best candidate selection -------------------------------------------------
# Keep the Upsilon candidate(s) built from the best-ranked tag B; ties share rank 1.
# NOTE: 'Btag_rank' is an alias registered further down in this script; this relies
# on aliases being resolved only when processing starts — confirm if reordering.
bestTagRankInfo = 'Y4S_Btag_rank'
ma.rankByLowest('Upsilon', 'Btag_rank', allowMultiRank=True, outputVariable=bestTagRankInfo, numBest=0, path=main)
variables.addAlias('Upsilon_Btag_rank', f'extraInfo({bestTagRankInfo})')
ma.applyCuts('Upsilon', 'Upsilon_Btag_rank == 1', path=main)

# ---------------------------------------------------------------------------
# Ntuple content
#
# Decay-tree indices used by the aliases below (fixed by the decay strings
# 'Upsilon -> B-:generic B+:eta' and 'B+:... -> anti-D0:all eta:all <lepton> ?nu'):
#   Upsilon : daughter(0) = tag B, daughter(1) = signal B
#   signal B: daughter(0) = anti-D0, daughter(1) = eta, daughter(2) = lepton
# ---------------------------------------------------------------------------

# Work on a copy: "+=" on vc.event_shape itself would extend the shared list
# inside variables.collections in place.
variableList = list(vc.event_shape)

# Decay-mode identifiers of the Upsilon, the signal B and the signal-side D0.
variables.addAlias('FlavorModeID', 'extraInfo(decayModeID)')
variables.addAlias('BDecayModeID', 'daughter(1, extraInfo(decayModeID))')
variables.addAlias('DDecayModeID', 'daughter(1, daughter(0, extraInfo(decayModeID)))')
variableList += ['FlavorModeID', 'BDecayModeID', 'DDecayModeID']

commonVariables = vc.kinematics + vc.inv_mass
vertexVariables = ['distance', 'significanceOfDistance', 'dx', 'dy', 'dz', 'x', 'y', 'z',
                   'x_uncertainty', 'y_uncertainty', 'z_uncertainty', 'dr', 'dphi', 'dcosTheta', 'chiProb']
commonVariables += vertexVariables

# MC-truth variables.
isSignalVariables = ['isSignal', 'isSignalAcceptMissingNeutrino', 'isSignalAcceptBremsPhotons', 'isCloneTrack', 'mcErrors', 'mcPDG']
MCKinematicsVariables = ['mcE', 'mcP', 'mcPX', 'mcPY', 'mcPZ']
MCOriginVariables = ['genMotherID', 'genMotherP', 'genMotherPDG', 'genParticleID', 'genMotherPDG(1)']
variables.addAlias('hasB0Ancestor', 'hasAncestor(511,1)')
variables.addAlias('hasB0barAncestor', 'hasAncestor(-511,1)')
variables.addAlias('hasBpAncestor', 'hasAncestor(521,1)')
variables.addAlias('hasBmAncestor', 'hasAncestor(-521,1)')
MCAncestorVariables = ['hasB0Ancestor', 'hasB0barAncestor', 'hasBpAncestor', 'hasBmAncestor']
MCVariables = isSignalVariables + MCKinematicsVariables + MCOriginVariables + MCAncestorVariables

# Lepton columns (l_*, l_CMS_*): read directly from daughter(2) of the signal B,
# the same way the D and Btag columns below are built. Daughter 1 is the eta.
leptonLab = 'daughter(1, daughter(2, {variable}))'
leptonCMS = 'daughter(1, daughter(2, useCMSFrame({variable})))'
variableList += vu.create_aliases(MCVariables + vc.kinematics, leptonLab, 'l')
variableList += vu.create_aliases(MCKinematicsVariables + vc.kinematics, leptonCMS, 'l_CMS')
# PID and track-quality columns. NOTE(review): nothing in this script fills
# 'PIDmu'/'PIDe'; mapping them to muonID/electronID is the assumed intent.
# extraInfo(bremsCorrected) is presumably set only on brems-corrected electrons
# (NaN for muons) — confirm.
leptonExtraVariables = {
    'PIDmu': 'muonID',
    'PIDe': 'electronID',
    'bremsCorrected': 'extraInfo(bremsCorrected)',
    'dr': 'dr',
    'dz': 'dz',
}
for column, expression in leptonExtraVariables.items():
    variables.addAlias(f'l_{column}', leptonLab.format(variable=expression))
    variableList.append(f'l_{column}')

# Signal B and its D0: kinematics, invariant mass, vertex and MC-truth columns.
commonVariables += MCVariables
variableList += vu.create_aliases(commonVariables, 'daughter(1, {variable})', 'Bsig')
variableList += vu.create_aliases(commonVariables, 'daughter(1, daughter(0, {variable}))', 'D')

# Tag B.
BtagVariables = vc.deltae_mbc + vc.kinematics
BtagVariables += ['isSignal', 'mcErrors']
variableList += vu.create_aliases(BtagVariables, 'daughter(0, {variable})', 'Btag')

# CMS-frame kinematics of the tag B and of the signal-side D0.
for variable in vc.kinematics:
    variables.addAlias(f'Btag_CMS_{variable}', f'daughter(0, useCMSFrame({variable}))')
    variables.addAlias(f'D_CMS_{variable}', f'daughter(1, daughter(0, useCMSFrame({variable})))')
    variableList += [f'Btag_CMS_{variable}', f'D_CMS_{variable}']

# FEI output of the tag B. 'Btag_rank' is also used by the best-candidate
# selection (rankByLowest on 'Upsilon') earlier in the path.
variables.addAlias('Btag_sigProb', 'daughter(0, extraInfo(SignalProbability))')
variables.addAlias('Btag_rank', 'daughter(0, extraInfo(sigProb_rank))')
variables.addAlias('Btag_dmID', 'daughter(0, extraInfo(decayModeID))')
variableList += ['Btag_sigProb', 'Btag_rank', 'Btag_dmID']
variableList += ['Ecms']

ma.variablesToNtuple('Upsilon', variables=variableList, filename='advanced_ntuple.root', treename='tree', path=main)

# process events and print call statistics
basf2.process(main)
print(basf2.statistics)
