// tmvaTrain.cc P. Manek and G. Cowan, 2019 // simple program to train TMVA classifier (here, Fisher) #include #include #include #include #include #include #include #include int main(){ // Create ouput file, factory object and open the input file TFile* outputFile = TFile::Open("TMVA.root", "RECREATE"); TMVA::Factory* factory = new TMVA::Factory("tmvaTest", outputFile, ""); TMVA::DataLoader* dataLoader = new TMVA::DataLoader("dataset"); TFile* trainingFile = new TFile("../generate/trainingData.root"); TFile* testFile = new TFile("../generate/testData.root"); // get the TTree objects from the input files TTree* sigTrain = (TTree*)trainingFile->Get("sig"); TTree* bkgTrain = (TTree*)trainingFile->Get("bkg"); int nSigTrain = sigTrain->GetEntries(); int nBkgTrain = bkgTrain->GetEntries(); TTree* sigTest = (TTree*)testFile->Get("sig"); TTree* bkgTest = (TTree*)testFile->Get("bkg"); int nSigTest = sigTest->GetEntries(); int nBkgTest = bkgTest->GetEntries(); // global event weights (see below for setting event-wise weights) double sigWeight = 1.0; double bkgWeight = 1.0; dataLoader->AddSignalTree(sigTrain, sigWeight, TMVA::Types::kTraining); dataLoader->AddBackgroundTree(bkgTrain, bkgWeight, TMVA::Types::kTraining); dataLoader->AddSignalTree(sigTest, sigWeight, TMVA::Types::kTesting); dataLoader->AddBackgroundTree(bkgTest, bkgWeight, TMVA::Types::kTesting); // Define the input variables that shall be used for the MVA training // (the variables used in the expression must exist in the original TTree). dataLoader->AddVariable("x", 'F'); dataLoader->AddVariable("y", 'F'); dataLoader->AddVariable("z", 'F'); // Book MVA methods (see TMVA manual). factory->BookMethod(dataLoader, TMVA::Types::kFisher, "Fisher", "H:!V:Fisher"); // Train, test and evaluate all methods factory->TrainAllMethods(); factory->TestAllMethods(); factory->EvaluateAllMethods(); // Save the output and finish up outputFile->Close(); std::cout << "==> wrote root file TMVA.root" << std::endl; std::cout << "==> TMVAnalysis is done!" << std::endl; delete factory; delete dataLoader; return 0; }