Changeset 106


Ignore:
Timestamp:
05/26/16 18:24:02 (8 years ago)
Author:
atrautsch
Message:

configurations pulled to top

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/GPTraining.java

    r104 r106  
    4444/** 
    4545 * Genetic Programming Trainer 
     46 *  
    4647 * 
     48 * - GPRun is a Run of a complete Genetic Programm Evolution, we want several complete runs. 
     49 * - GPVClassifier is the Validation Classifier 
     50 * - GPVVClassifier is the Validation-Voting Classifier 
     51 *  
     52 * config: <setwisetrainer name="GPTraining" param="GPVVClassifier" /> 
    4753 */ 
    4854public class GPTraining implements ISetWiseTrainingStrategy, IWekaCompatibleTrainer  { 
    4955     
    50     private GPVVClassifier classifier = new GPVVClassifier(); 
     56    private GPVClassifier classifier = null; 
    5157     
    5258    private int populationSize = 1000; 
     
    5460    private int initMaxDepth = 6; 
    5561    private int tournamentSize = 7; 
    56      
     62    private int maxGenerations = 50; 
     63    private double errorType2Weight = 1; 
     64    private int numberRuns = 200;  // 200 in the paper 
     65    private int maxDepth = 20;  // max depth within one program 
     66    private int maxNodes = 100;  // max nodes within one program 
     67 
    5768    @Override 
    5869    public void setParameter(String parameters) { 
    59         // todo, which type of classifier? GPV, GPVV? 
    60         // more config population size, etc. 
    61         // todo: voting for gpvv only 3 votes necessary? 
     70        if(parameters.equals("GPVVClassifier")) { 
     71            this.classifier = new GPVVClassifier(); 
     72            ((GPVVClassifier)this.classifier).configure(populationSize, initMinDepth, initMaxDepth, tournamentSize, maxGenerations, errorType2Weight, numberRuns, maxDepth, maxNodes); 
     73        }else if(parameters.equals("GPVClassifier")) { 
     74            this.classifier = new GPVClassifier(); 
     75            ((GPVClassifier)this.classifier).configure(populationSize, initMinDepth, initMaxDepth, tournamentSize, maxGenerations, errorType2Weight, numberRuns, maxDepth, maxNodes); 
     76        }else { 
     77            // default 
     78            this.classifier = new GPVVClassifier(); 
     79            ((GPVVClassifier)this.classifier).configure(populationSize, initMinDepth, initMaxDepth, tournamentSize, maxGenerations, errorType2Weight, numberRuns, maxDepth, maxNodes); 
     80        } 
    6281    } 
    6382 
     
    105124    } 
    106125     
    107     // one gprun, we want several for voting 
     126    /** 
     127     * One Run of a GP Classifier 
     128     * we want several runs to mitigate problems with local maxima/minima  
     129     */ 
    108130    public class GPRun extends AbstractClassifier { 
    109131        private static final long serialVersionUID = -4250422550107888789L; 
    110132 
    111         private int populationSize = 1000; 
    112         private int initMinDepth = 2; 
    113         private int initMaxDepth = 6; 
    114         private int tournamentSize = 7; 
    115         private int maxGenerations = 50; 
     133        private int populationSize; 
     134        private int initMinDepth; 
     135        private int initMaxDepth; 
     136        private int tournamentSize; 
     137        private int maxGenerations; 
     138        private double errorType2Weight; 
     139        private int maxDepth; 
     140        private int maxNodes; 
    116141         
    117142        private GPGenotype gp; 
    118143        private GPProblem problem; 
    119144         
    120         public void configure(int populationSize, int initMinDepth, int initMaxDepth, int tournamentSize, int maxGenerations) { 
     145        public void configure(int populationSize, int initMinDepth, int initMaxDepth, int tournamentSize, int maxGenerations, double errorType2Weight, int maxDepth, int maxNodes) { 
    121146            this.populationSize = populationSize; 
    122147            this.initMinDepth = initMinDepth; 
     
    124149            this.tournamentSize = tournamentSize; 
    125150            this.maxGenerations = maxGenerations; 
     151            this.errorType2Weight = errorType2Weight; 
     152            this.maxDepth = maxDepth; 
     153            this.maxNodes = maxNodes; 
    126154        } 
    127155         
     
    133161            return ((CrossPareGP)this.problem).getVariables(); 
    134162        } 
    135          
    136         public void setEvaldata(Instances testdata) { 
    137              
     163 
     164        @Override 
     165        public void buildClassifier(Instances traindata) throws Exception { 
     166            InstanceData train = new InstanceData(traindata);             
     167            this.problem = new CrossPareGP(train.getX(), train.getY(), this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize, this.errorType2Weight, this.maxDepth, this.maxNodes); 
     168            this.gp = problem.create(); 
     169            this.gp.evolve(this.maxGenerations); 
    138170        } 
    139171         
     
    142174         */ 
    143175        class CrossPareGP extends GPProblem { 
    144              
    145             //private static final long serialVersionUID = 7526472295622776147L; 
    146  
    147176            private double[][] instances; 
    148177            private boolean[] output; 
    149178 
     179            private int maxDepth; 
     180            private int maxNodes; 
     181             
    150182            private Variable[] x; 
    151183 
    152             public CrossPareGP(double[][] instances, boolean[] output, int populationSize, int minInitDept, int maxInitDepth, int tournamentSize) throws InvalidConfigurationException { 
     184            public CrossPareGP(double[][] instances, boolean[] output, int populationSize, int minInitDept, int maxInitDepth, int tournamentSize, double errorType2Weight, int maxDepth, int maxNodes) throws InvalidConfigurationException { 
    153185                super(new GPConfiguration()); 
    154186                 
    155187                this.instances = instances; 
    156188                this.output = output; 
     189                this.maxDepth = maxDepth; 
     190                this.maxNodes = maxNodes; 
    157191 
    158192                Configuration.reset(); 
    159193                GPConfiguration config = this.getGPConfiguration(); 
    160                 //config.reset(); 
    161194                 
    162195                this.x = new Variable[this.instances[0].length]; 
    163  
    164196                
    165197                for(int j=0; j < this.x.length; j++) { 
     
    170202                //config.setGPFitnessEvaluator(new DefaultGPFitnessEvaluator()); // bigger fitness is better 
    171203 
    172                 // from paper: 2-6 
    173204                config.setMinInitDepth(minInitDept); 
    174205                config.setMaxInitDepth(maxInitDepth); 
    175  
    176                 // missing from paper 
    177                 // config.setMaxDepth(20); 
    178  
     206                 
    179207                config.setCrossoverProb((float)0.60); 
    180208                config.setReproductionProb((float)0.10); 
     
    183211                config.setSelectionMethod(new TournamentSelector(tournamentSize)); 
    184212 
    185                 // from paper 1000 
    186213                config.setPopulationSize(populationSize); 
    187214 
    188                 // BranchTypingCross 
    189215                config.setMaxCrossoverDepth(4); 
    190                 config.setFitnessFunction(new CrossPareFitness(this.x, this.instances, this.output)); 
     216                config.setFitnessFunction(new CrossPareFitness(this.x, this.instances, this.output, errorType2Weight)); 
    191217                config.setStrictProgramCreation(true); 
    192218            } 
     
    230256                    comb, 
    231257                }; 
    232  
    233                 GPGenotype result = GPGenotype.randomInitialGenotype(config, types, argTypes, nodeSets, 20, true); // 20 = maxNodes, true = verbose output 
     258                 
     259                // we only have one chromosome so this suffices 
     260                int minDepths[] = {config.getMinInitDepth()}; 
     261                int maxDepths[] = {this.maxDepth}; 
     262                GPGenotype result = GPGenotype.randomInitialGenotype(config, types, argTypes, nodeSets, minDepths, maxDepths, this.maxNodes, false); // 40 = maxNodes, true = verbose output 
    234263 
    235264                return result; 
     
    250279            private boolean[] output; 
    251280 
    252             private double error_type2_weight = 1.0; 
     281            private double errorType2Weight = 1.0; 
    253282 
    254283            // needed in evaluate 
    255             private Object[] NO_ARGS = new Object[0]; 
     284            //private Object[] NO_ARGS = new Object[0]; 
    256285 
    257286            private double sfitness = 0.0f; 
    258             private int error_type1 = 0; 
    259             private int error_type2 = 0; 
    260  
    261             public CrossPareFitness(Variable[] x, double[][] instances, boolean[] output) { 
     287            private int errorType1 = 0; 
     288            private int errorType2 = 0; 
     289 
     290            public CrossPareFitness(Variable[] x, double[][] instances, boolean[] output, double errorType2Weight) { 
    262291                this.x = x; 
    263292                this.instances = instances; 
    264293                this.output = output; 
     294                this.errorType2Weight = errorType2Weight; 
    265295            } 
    266296 
    267297            public int getErrorType1() { 
    268                 return this.error_type1; 
     298                return this.errorType1; 
    269299            } 
    270300 
    271301            public int getErrorType2() { 
    272                 return this.error_type2; 
     302                return this.errorType2; 
    273303            } 
    274304 
     
    288318 
    289319                // count classification errors 
    290                 this.error_type1 = 0; 
    291                 this.error_type2 = 0; 
     320                this.errorType1 = 0; 
     321                this.errorType2 = 0; 
    292322 
    293323                for(int i=0; i < this.instances.length; i++) { 
     
    299329 
    300330                    // value gives us a double, if < 0.5 we set this instance as faulty 
    301                     value = program.execute_double(0, NO_ARGS);  // todo: test with this.x 
     331                    value = program.execute_double(0, this.x);  // todo: test with this.x 
    302332 
    303333                    if(value < 0.5) { 
    304334                        if(this.output[i] != true) { 
    305                             this.error_type1 += 1; 
     335                            this.errorType1 += 1; 
    306336                        } 
    307337                    }else { 
    308338                        if(this.output[i] == true) { 
    309                             this.error_type2 += 1; 
     339                            this.errorType2 += 1; 
    310340                        } 
    311341                    } 
     
    313343 
    314344                // now calc pfitness 
    315                 pfitness = (this.error_type1 + this.error_type2_weight * this.error_type2) / this.instances.length; 
     345                pfitness = (this.errorType1 + this.errorType2Weight * this.errorType2) / this.instances.length; 
    316346 
    317347                //System.out.println("pfitness: " + pfitness); 
     
    321351                if(program.getChromosome(0).getSize(0) < 10) { 
    322352                    program.setApplicationData(10.0f); 
    323                     this.sfitness = 10.0f; 
    324                     //System.out.println("wenige nodes: "+program.getChromosome(0).getSize(0)); 
    325                     //System.out.println(program.toStringNorm(0)); 
    326                 } 
    327  
    328                 // sfitness counts the number of nodes in the tree, if it is lower than 10 fitness is increased by 10 
     353                } 
    329354 
    330355                return pfitness; 
    331356            } 
    332357        } 
    333  
    334         @Override 
    335         public void buildClassifier(Instances traindata) throws Exception { 
    336             InstanceData train = new InstanceData(traindata);             
    337             this.problem = new CrossPareGP(train.getX(), train.getY(), this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize); 
    338             this.gp = problem.create(); 
    339             this.gp.evolve(this.maxGenerations); 
    340         } 
     358         
     359        /** 
     360         * Custom GT implementation used in the GP Algorithm. 
     361         */ 
     362         public class GT extends MathCommand implements ICloneable { 
     363              
     364             private static final long serialVersionUID = 113454184817L; 
     365 
     366             public GT(final GPConfiguration a_conf, java.lang.Class a_returnType) throws InvalidConfigurationException { 
     367                 super(a_conf, 2, a_returnType); 
     368             } 
     369 
     370             public String toString() { 
     371                 return "GT(&1, &2)"; 
     372             } 
     373 
     374             public String getName() { 
     375                 return "GT"; 
     376             }    
     377 
     378             public float execute_float(ProgramChromosome c, int n, Object[] args) { 
     379                 float f1 = c.execute_float(n, 0, args); 
     380                 float f2 = c.execute_float(n, 1, args); 
     381 
     382                 float ret = 1.0f; 
     383                 if(f1 > f2) { 
     384                     ret = 0.0f; 
     385                 } 
     386 
     387                 return ret; 
     388             } 
     389 
     390             public double execute_double(ProgramChromosome c, int n, Object[] args) { 
     391                 double f1 = c.execute_double(n, 0, args); 
     392                 double f2 = c.execute_double(n, 1, args); 
     393 
     394                 double ret = 1; 
     395                 if(f1 > f2)  { 
     396                     ret = 0; 
     397                 } 
     398                 return ret; 
     399             } 
     400 
     401             public Object clone() { 
     402                 try { 
     403                     GT result = new GT(getGPConfiguration(), getReturnType()); 
     404                     return result; 
     405                 }catch(Exception ex) { 
     406                     throw new CloneException(ex); 
     407                 } 
     408             } 
     409         } 
    341410    } 
    342411     
     
    349418     */ 
    350419    public class GPVVClassifier extends GPVClassifier { 
    351          
     420 
     421        private static final long serialVersionUID = -654710583852839901L; 
    352422        private List<Classifier> classifiers = null; 
    353423         
     
    362432            // each classifier is trained with one project from the set 
    363433            // then is evaluated on the rest 
     434            classifiers = new LinkedList<>(); 
    364435            for(int i=0; i < traindataSet.size(); i++) { 
    365436                 
     
    367438                LinkedList<Classifier> candidates = new LinkedList<>(); 
    368439                 
    369                 // 200 runs 
    370                  
    371                 for(int k=0; k < 200; k++) { 
     440                // number of runs 
     441                for(int k=0; k < this.numberRuns; k++) { 
    372442                    Classifier classifier = new GPRun(); 
     443                    ((GPRun)classifier).configure(this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize, this.maxGenerations, this.errorType2Weight, this.maxDepth, this.maxNodes); 
    373444                     
    374445                    // one project is training data 
     
    381452                            // if type1 and type2 errors are < 0.5 we allow the model in the final voting 
    382453                            errors = this.evaluate((GPRun)classifier, traindataSet.get(j)); 
    383                             if((errors[0] / traindataSet.get(j).numInstances()) < 0.5 && (errors[0] / traindataSet.get(j).numInstances()) < 0.5) { 
    384                                 candidates.add(classifier);                             
     454                            if((errors[0] < 0.5) && (errors[0] < 0.5)) { 
     455                                candidates.add(classifier); 
    385456                            } 
    386457                        } 
     
    406477                // now we have the best classifier for this training data 
    407478                classifiers.add(best); 
    408                  
    409479            } 
    410480        } 
     
    417487             
    418488            int vote_positive = 0; 
    419             int vote_negative = 0; 
    420489             
    421490            for (int i = 0; i < classifiers.size(); i++) { 
     
    432501                if(fitest.execute_double(0, vars) < 0.5) { 
    433502                    vote_positive += 1; 
    434                 }else { 
    435                     vote_negative += 1; 
    436503                } 
    437504            } 
     
    450517     * 
    451518     * for one test data set: 
    452      *   for one in 6 possible training data sets: 
    453      *     For 200 GP Runs: 
     519     *   for one in X possible training data sets: 
     520     *     For Y GP Runs: 
    454521     *       train one Classifier with this training data 
    455522     *       then evaluate the classifier with the remaining project 
     
    465532        private static final long serialVersionUID = 3708714057579101522L; 
    466533 
    467  
     534        protected int populationSize; 
     535        protected int initMinDepth; 
     536        protected int initMaxDepth; 
     537        protected int tournamentSize; 
     538        protected int maxGenerations; 
     539        protected double errorType2Weight; 
     540        protected int numberRuns; 
     541        protected int maxDepth; 
     542        protected int maxNodes; 
     543 
     544        /** 
     545         * Configure the GP Params and number of Runs 
     546         *  
     547         * @param populationSize 
     548         * @param initMinDepth 
     549         * @param initMaxDepth 
     550         * @param tournamentSize 
     551         * @param maxGenerations 
     552         * @param errorType2Weight 
     553         */ 
     554        public void configure(int populationSize, int initMinDepth, int initMaxDepth, int tournamentSize, int maxGenerations, double errorType2Weight, int numberRuns, int maxDepth, int maxNodes) { 
     555            this.populationSize = populationSize; 
     556            this.initMinDepth = initMinDepth; 
     557            this.initMaxDepth = initMaxDepth; 
     558            this.tournamentSize = tournamentSize; 
     559            this.maxGenerations = maxGenerations; 
     560            this.errorType2Weight = errorType2Weight; 
     561            this.numberRuns = numberRuns; 
     562            this.maxDepth = maxDepth; 
     563            this.maxNodes = maxNodes; 
     564        } 
     565         
    468566        /** Build the GP Multiple Data Sets Validation Classifier 
    469567         *  
     
    485583                 
    486584                // 200 runs 
    487                 for(int k=0; k < 200; k++) { 
     585                for(int k=0; k < this.numberRuns; k++) { 
    488586                    Classifier classifier = new GPRun(); 
     587                    ((GPRun)classifier).configure(this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize, this.maxGenerations, this.errorType2Weight, this.maxDepth, this.maxNodes); 
    489588                     
    490589                    // one project is training data 
     
    498597                            // if type1 and type2 errors are < 0.5 we allow the model in the final voting 
    499598                            errors = this.evaluate((GPRun)classifier, traindataSet.get(j)); 
    500                             if((errors[0] / traindataSet.get(j).numInstances()) < 0.5 && (errors[0] / traindataSet.get(j).numInstances()) < 0.5) { 
     599                            if((errors[0] < 0.5) && (errors[0] < 0.5)) { 
    501600                                candidates.add(classifier);                             
    502601                            } 
     
    549648        public void buildClassifier(Instances traindata) throws Exception { 
    550649            final Classifier classifier = new GPRun(); 
     650            ((GPRun)classifier).configure(populationSize, initMinDepth, initMaxDepth, tournamentSize, maxGenerations, errorType2Weight, this.maxDepth, this.maxNodes); 
    551651            classifier.buildClassifier(traindata); 
    552652            classifiers.add(classifier); 
     
    562662            int error_type1 = 0; 
    563663            int error_type2 = 0; 
    564             int number_instances = evalData.numInstances(); 
     664            int positive = 0; 
     665            int negative = 0; 
    565666             
    566667            for(Instance instance: evalData) { 
     
    571672                 
    572673                classification = fitest.execute_double(0, vars); 
     674                 
     675                // we need to count the absolutes of positives for percentage 
     676                if(instance.classValue() == 1.0) { 
     677                    positive +=1; 
     678                }else { 
     679                    negative +=1; 
     680                } 
    573681                 
    574682                // classification < 0.5 we say defective 
     
    584692            } 
    585693             
    586             double et1_per = error_type1 / number_instances; 
    587             double et2_per = error_type2 / number_instances;  
    588              
    589             // return some kind of fehlerquote? 
    590             //return (error_type1 + error_type2) / number_instances; 
    591             return new double[]{error_type1, error_type2}; 
     694            // return error types percentages for the types  
     695            double et1_per = error_type1 / negative; 
     696            double et2_per = error_type2 / positive;  
     697            return new double[]{et1_per, et2_per}; 
    592698        } 
    593699         
     
    614720        } 
    615721    } 
    616      
    617      
    618     /** 
    619     * Custom GT implementation from the paper 
    620     */ 
    621     public class GT extends MathCommand implements ICloneable { 
    622          
    623         private static final long serialVersionUID = 113454184817L; 
    624  
    625         public GT(final GPConfiguration a_conf, java.lang.Class a_returnType) throws InvalidConfigurationException { 
    626             super(a_conf, 2, a_returnType); 
    627         } 
    628  
    629         public String toString() { 
    630             return "GT(&1, &2)"; 
    631         } 
    632  
    633         public String getName() { 
    634             return "GT"; 
    635         }    
    636  
    637         public float execute_float(ProgramChromosome c, int n, Object[] args) { 
    638             float f1 = c.execute_float(n, 0, args); 
    639             float f2 = c.execute_float(n, 1, args); 
    640  
    641             float ret = 1.0f; 
    642             if(f1 > f2) { 
    643                 ret = 0.0f; 
    644             } 
    645  
    646             return ret; 
    647         } 
    648  
    649         public double execute_double(ProgramChromosome c, int n, Object[] args) { 
    650             double f1 = c.execute_double(n, 0, args); 
    651             double f2 = c.execute_double(n, 1, args); 
    652  
    653             double ret = 1; 
    654             if(f1 > f2)  { 
    655                 ret = 0; 
    656             } 
    657             return ret; 
    658         } 
    659  
    660         public Object clone() { 
    661             try { 
    662                 GT result = new GT(getGPConfiguration(), getReturnType()); 
    663                 return result; 
    664             }catch(Exception ex) { 
    665                 throw new CloneException(ex); 
    666             } 
    667         } 
    668     } 
    669722} 
Note: See TracChangeset for help on using the changeset viewer.