source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/GPTraining.java @ 103

Last change on this file since 103 was 103, checked in by atrautsch, 8 years ago

GPTraining Implementation Update

File size: 20.3 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.util.List;
4
5import org.apache.commons.collections4.list.SetUniqueList;
6
7import weka.classifiers.AbstractClassifier;
8import weka.classifiers.Classifier;
9import weka.core.Instance;
10import weka.core.Instances;
11import org.apache.commons.lang3.ArrayUtils;
12import org.jgap.Configuration;
13import org.jgap.InvalidConfigurationException;
14import org.jgap.gp.CommandGene;
15import org.jgap.gp.GPProblem;
16
17import org.jgap.gp.function.Add;
18import org.jgap.gp.function.Multiply;
19import org.jgap.gp.function.Log;
20import org.jgap.gp.function.Subtract;
21import org.jgap.gp.function.Divide;
22import org.jgap.gp.function.Sine;
23import org.jgap.gp.function.Cosine;
24import org.jgap.gp.function.Max;
25import org.jgap.gp.function.Exp;
26
27import org.jgap.gp.impl.DeltaGPFitnessEvaluator;
28import org.jgap.gp.impl.GPConfiguration;
29import org.jgap.gp.impl.GPGenotype;
30import org.jgap.gp.impl.TournamentSelector;
31import org.jgap.gp.terminal.Terminal;
32import org.jgap.gp.GPFitnessFunction;
33import org.jgap.gp.IGPProgram;
34import org.jgap.gp.terminal.Variable;
35import org.jgap.gp.MathCommand;
36import org.jgap.util.ICloneable;
37
38import de.ugoe.cs.cpdp.util.WekaUtils;
39
40import org.jgap.gp.impl.ProgramChromosome;
41import org.jgap.util.CloneException;
42
43/**
44 * Genetic Programming Trainer
45 *
46 */
47public class GPTraining implements ISetWiseTrainingStrategy, IWekaCompatibleTrainer  {
48   
49    private GPVClassifier classifier = new GPVClassifier();
50   
51    private int populationSize = 1000;
52    private int initMinDepth = 2;
53    private int initMaxDepth = 6;
54    private int tournamentSize = 7;
55   
56    @Override
57    public void setParameter(String parameters) {
58        // todo, which type of classifier? GPV, GPVV?
59        // more config population size, etc.
60        // todo: voting for gpvv only 3 votes necessary?
61    }
62
63    @Override
64    public void apply(SetUniqueList<Instances> traindataSet) {
65        try {
66            classifier.buildClassifier(traindataSet);
67        }catch(Exception e) {
68            throw new RuntimeException(e);
69        }
70    }
71
72    @Override
73    public String getName() {
74        return "GPTraining";
75    }
76
77    @Override
78    public Classifier getClassifier() {
79        return this.classifier;
80    }
81   
82    public class InstanceData {
83        private double[][] instances_x;
84        private boolean[] instances_y;
85       
86        public InstanceData(Instances instances) {
87            this.instances_x = new double[instances.numInstances()][instances.numAttributes()-1];
88            this.instances_y = new boolean[instances.numInstances()];
89           
90            Instance current;
91            for(int i=0; i < this.instances_x.length; i++) {
92                current = instances.get(i);
93                this.instances_x[i] = WekaUtils.instanceValues(current);
94                this.instances_y[i] = 1.0 == current.classValue();
95            }
96        }
97       
98        public double[][] getX() {
99            return instances_x;
100        }
101        public boolean[] getY() {
102            return instances_y;
103        }
104    }
105   
106    // one gprun, we want several for voting
107    public class GPRun extends AbstractClassifier {
108        private static final long serialVersionUID = -4250422550107888789L;
109
110        private int populationSize = 1000;
111        private int initMinDepth = 2;
112        private int initMaxDepth = 6;
113        private int tournamentSize = 7;
114        private int maxGenerations = 50;
115       
116        private GPGenotype gp;
117        private GPProblem problem;
118       
119        public void configure(int populationSize, int initMinDepth, int initMaxDepth, int tournamentSize, int maxGenerations) {
120            this.populationSize = populationSize;
121            this.initMinDepth = initMinDepth;
122            this.initMaxDepth = initMaxDepth;
123            this.tournamentSize = tournamentSize;
124            this.maxGenerations = maxGenerations;
125        }
126       
127        public GPGenotype getGp() {
128            return this.gp;
129        }
130       
131        public Variable[] getVariables() {
132            return ((CrossPareGP)this.problem).getVariables();
133        }
134       
135        public void setEvaldata(Instances testdata) {
136           
137        }
138       
139        /**
140         * GPProblem implementation
141         */
142        class CrossPareGP extends GPProblem {
143           
144            //private static final long serialVersionUID = 7526472295622776147L;
145
146            private double[][] instances;
147            private boolean[] output;
148
149            private Variable[] x;
150
151            public CrossPareGP(double[][] instances, boolean[] output, int populationSize, int minInitDept, int maxInitDepth, int tournamentSize) throws InvalidConfigurationException {
152                super(new GPConfiguration());
153               
154                this.instances = instances;
155                this.output = output;
156
157                Configuration.reset();
158                GPConfiguration config = this.getGPConfiguration();
159                //config.reset();
160               
161                this.x = new Variable[this.instances[0].length];
162
163               
164                for(int j=0; j < this.x.length; j++) {
165                    this.x[j] = Variable.create(config, "X"+j, CommandGene.DoubleClass);   
166                }
167
168                config.setGPFitnessEvaluator(new DeltaGPFitnessEvaluator()); // smaller fitness is better
169                //config.setGPFitnessEvaluator(new DefaultGPFitnessEvaluator()); // bigger fitness is better
170
171                // from paper: 2-6
172                config.setMinInitDepth(minInitDept);
173                config.setMaxInitDepth(maxInitDepth);
174
175                // missing from paper
176                // config.setMaxDepth(20);
177
178                config.setCrossoverProb((float)0.60);
179                config.setReproductionProb((float)0.10);
180                config.setMutationProb((float)0.30);
181
182                config.setSelectionMethod(new TournamentSelector(tournamentSize));
183
184                // from paper 1000
185                config.setPopulationSize(populationSize);
186
187                // BranchTypingCross
188                config.setMaxCrossoverDepth(4);
189                config.setFitnessFunction(new CrossPareFitness(this.x, this.instances, this.output));
190                config.setStrictProgramCreation(true);
191            }
192
193            // used for running the fitness function again for testing
194            public Variable[] getVariables() {
195                return this.x;
196            }
197
198
199            public GPGenotype create() throws InvalidConfigurationException {
200                GPConfiguration config = this.getGPConfiguration();
201
202                // return type
203                Class[] types = {CommandGene.DoubleClass};
204
205                // Arguments of result-producing chromosome: none
206                Class[][] argTypes = { {} };
207
208                // variables + functions, we set the variables with the values of the instances here
209                CommandGene[] vars = new CommandGene[this.instances[0].length];
210                for(int j=0; j < this.instances[0].length; j++) {
211                    vars[j] = this.x[j];
212                }
213                CommandGene[] funcs = {
214                    new Add(config, CommandGene.DoubleClass),
215                    new Subtract(config, CommandGene.DoubleClass),
216                    new Multiply(config, CommandGene.DoubleClass),
217                    new Divide(config, CommandGene.DoubleClass),
218                    new Sine(config, CommandGene.DoubleClass),
219                    new Cosine(config, CommandGene.DoubleClass),
220                    new Exp(config, CommandGene.DoubleClass),
221                    new Log(config, CommandGene.DoubleClass),
222                    new GT(config, CommandGene.DoubleClass),
223                    new Max(config, CommandGene.DoubleClass),
224                    new Terminal(config, CommandGene.DoubleClass, -100.0, 100.0, true), // min, max, whole numbers
225                };
226
227                CommandGene[] comb = (CommandGene[])ArrayUtils.addAll(vars, funcs);
228                CommandGene[][] nodeSets = {
229                    comb,
230                };
231
232                GPGenotype result = GPGenotype.randomInitialGenotype(config, types, argTypes, nodeSets, 20, true); // 20 = maxNodes, true = verbose output
233
234                return result;
235            }
236        }
237
238       
239        /**
240         * Fitness function
241         */
242        class CrossPareFitness extends GPFitnessFunction {
243           
244            private static final long serialVersionUID = 75234832484387L;
245
246            private Variable[] x;
247
248            private double[][] instances;
249            private boolean[] output;
250
251            private double error_type2_weight = 1.0;
252
253            // needed in evaluate
254            private Object[] NO_ARGS = new Object[0];
255
256            private double sfitness = 0.0f;
257            private int error_type1 = 0;
258            private int error_type2 = 0;
259
260            public CrossPareFitness(Variable[] x, double[][] instances, boolean[] output) {
261                this.x = x;
262                this.instances = instances;
263                this.output = output;
264            }
265
266            public int getErrorType1() {
267                return this.error_type1;
268            }
269
270            public int getErrorType2() {
271                return this.error_type2;
272            }
273
274            public double getSecondFitness() {
275                return this.sfitness;
276            }
277
278            public int getNumInstances() {
279                return this.instances.length;
280            }
281
282            @Override
283            protected double evaluate(final IGPProgram program) {
284                double pfitness = 0.0f;
285                this.sfitness = 0.0f;
286                double value = 0.0f;
287
288                // count classification errors
289                this.error_type1 = 0;
290                this.error_type2 = 0;
291
292                for(int i=0; i < this.instances.length; i++) {
293
294                    // requires that we have a variable for each column of our dataset (attribute of instance)
295                    for(int j=0; j < this.x.length; j++) {
296                        this.x[j].set(this.instances[i][j]);
297                    }
298
299                    // value gives us a double, if < 0.5 we set this instance as faulty
300                    value = program.execute_double(0, NO_ARGS);  // todo: test with this.x
301
302                    if(value < 0.5) {
303                        if(this.output[i] != true) {
304                            this.error_type1 += 1;
305                        }
306                    }else {
307                        if(this.output[i] == true) {
308                            this.error_type2 += 1;
309                        }
310                    }
311                }
312
313                // now calc pfitness
314                pfitness = (this.error_type1 + this.error_type2_weight * this.error_type2) / this.instances.length;
315
316                //System.out.println("pfitness: " + pfitness);
317
318                // number of nodes in the programm, if lower then 10 we assign sFitness of 10
319                // we can set metadata with setProgramData to save this
320                if(program.getChromosome(0).getSize(0) < 10) {
321                    program.setApplicationData(10.0f);
322                    this.sfitness = 10.0f;
323                    //System.out.println("wenige nodes: "+program.getChromosome(0).getSize(0));
324                    //System.out.println(program.toStringNorm(0));
325                }
326
327                // sfitness counts the number of nodes in the tree, if it is lower than 10 fitness is increased by 10
328
329                return pfitness;
330            }
331        }
332
333        @Override
334        public void buildClassifier(Instances traindata) throws Exception {
335            InstanceData train = new InstanceData(traindata);           
336            this.problem = new CrossPareGP(train.getX(), train.getY(), this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize);
337            this.gp = problem.create();
338            this.gp.evolve(this.maxGenerations);
339        }
340    }
341   
342    /**
343     * GP Multiple Data Sets Validation-Voting Classifier
344     *
345     *
346     */
347    public class GPVVClassifier extends GPVClassifier {
348       
349        private List<Classifier> classifiers = null;
350       
351        @Override
352        public void buildClassifier(Instances arg0) throws Exception {
353            // TODO Auto-generated method stub
354           
355        }
356       
357        public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
358
359            // each classifier is trained with one project from the set
360            // then is evaluated on the rest
361            for(int i=0; i < traindataSet.size(); i++) {
362                Classifier classifier = new GPRun();
363               
364                // one project is training data
365                classifier.buildClassifier(traindataSet.get(i));
366               
367                double[] errors;
368               
369                // rest of the set is evaluation data, we evaluate now
370                for(int j=0; j < traindataSet.size(); j++) {
371                    if(j != i) {
372                        // if type1 and type2 errors are < 0.5 we allow the model in the final voting
373                        errors = this.evaluate((GPRun)classifier, traindataSet.get(j));
374                        if((errors[0] / traindataSet.get(j).numInstances()) < 0.5 && (errors[0] / traindataSet.get(j).numInstances()) < 0.5) {
375                            classifiers.add(classifier);                           
376                        }
377                    }
378                }
379            }
380        }
381       
382        /**
383         * Use the remaining classifiers for our voting
384         */
385        @Override
386        public double classifyInstance(Instance instance) {
387           
388            int vote_positive = 0;
389            int vote_negative = 0;
390           
391            for (int i = 0; i < classifiers.size(); i++) {
392                Classifier classifier = classifiers.get(i);
393               
394                GPGenotype gp = ((GPRun)classifier).getGp();
395                Variable[] vars = ((GPRun)classifier).getVariables();
396               
397                IGPProgram fitest = gp.getAllTimeBest();  // all time fitest
398                for(int j = 0; j < instance.numAttributes()-1; j++) {
399                   vars[j].set(instance.value(j));
400                }
401               
402                if(fitest.execute_double(0, vars) < 0.5) {
403                    vote_positive += 1;
404                }else {
405                    vote_negative += 1;
406                }
407            }
408           
409            if(vote_positive >= 3) {
410                return 1.0;
411            }else {
412                return 0.0;
413            }
414        }
415    }
416   
417    /**
418     * GP Multiple Data Sets Validation Classifier
419     *
420     *
421     * for one test data set:
422     *   for one in 6 possible training data sets:
423     *     For 200 GP Runs:
424     *       train one Classifier with this training data
425     *       then evaluate the classifier with the remaining project
426     *       if the candidate model performs bad (error type1 or type2 > 50%) discard it
427     * for the remaining model candidates the best one is used
428     *
429     */
430    public class GPVClassifier extends AbstractClassifier {
431       
432        private Classifier best = null;
433       
434        private static final long serialVersionUID = 3708714057579101522L;
435
436
437        /** Build the GP Multiple Data Sets Validation Classifier
438         *
439         * - Traindata one of the Instances of the Set (which one? The firsT? as it is a list?)
440         * - Testdata one other Instances of the Set (the next one? chose randomly?)
441         * - Evaluation the rest of the instances
442         *
443         * @param traindataSet
444         * @throws Exception
445         */
446        public void buildClassifier(SetUniqueList<Instances> traindataSet) throws Exception {
447
448            // each classifier is trained with one project from the set
449            // then is evaluated on the rest
450            for(int i=0; i < traindataSet.size(); i++) {
451                Classifier classifier = new GPRun();
452               
453                // one project is training data
454                classifier.buildClassifier(traindataSet.get(i));
455               
456                // rest of the set is evaluation data, we evaluate now
457                double smallest_error_count = Double.MAX_VALUE;
458                double[] errors;
459                for(int j=0; j < traindataSet.size(); j++) {
460                    if(j != i) {
461                        errors = this.evaluate((GPRun)classifier, traindataSet.get(j));
462                        if(errors[0]+errors[1] < smallest_error_count) {
463                            this.best = classifier;
464                        }
465                    }
466                }
467            }
468        }
469       
470        @Override
471        public void buildClassifier(Instances traindata) throws Exception {
472            final Classifier classifier = new GPRun();
473            classifier.buildClassifier(traindata);
474            best = classifier;
475        }
476       
477        public double[] evaluate(GPRun classifier, Instances evalData) {
478            GPGenotype gp = classifier.getGp();
479            Variable[] vars = classifier.getVariables();
480           
481            IGPProgram fitest = gp.getAllTimeBest();  // selects the fitest of all not just the last generation
482           
483            double classification;
484            int error_type1 = 0;
485            int error_type2 = 0;
486            int number_instances = evalData.numInstances();
487           
488            for(Instance instance: evalData) {
489               
490                for(int i = 0; i < instance.numAttributes()-1; i++) {
491                    vars[i].set(instance.value(i));
492                }
493               
494                classification = fitest.execute_double(0, vars);
495               
496                // classification < 0.5 we say defective
497                if(classification < 0.5) {
498                    if(instance.classValue() != 1.0) {
499                        error_type1 += 1;
500                    }
501                }else {
502                    if(instance.classValue() == 1.0) {
503                        error_type2 += 1;
504                    }
505                }
506            }
507           
508            double et1_per = error_type1 / number_instances;
509            double et2_per = error_type2 / number_instances;
510           
511            // return some kind of fehlerquote?
512            //return (error_type1 + error_type2) / number_instances;
513            return new double[]{error_type1, error_type2};
514        }
515       
516        /**
517         * Use only the best classifier from our evaluation phase
518         */
519        @Override
520        public double classifyInstance(Instance instance) {
521            GPGenotype gp = ((GPRun)best).getGp();
522            Variable[] vars = ((GPRun)best).getVariables();
523           
524            IGPProgram fitest = gp.getAllTimeBest();  // all time fitest
525            for(int i = 0; i < instance.numAttributes()-1; i++) {
526               vars[i].set(instance.value(i));
527            }
528           
529            double classification = fitest.execute_double(0, vars);
530           
531            if(classification < 0.5) {
532                return 1.0;
533            }else {
534                return 0.0;
535            }
536        }
537    }
538   
539   
540    /**
541    * Custom GT implementation from the paper
542    */
543    public class GT extends MathCommand implements ICloneable {
544       
545        private static final long serialVersionUID = 113454184817L;
546
547        public GT(final GPConfiguration a_conf, java.lang.Class a_returnType) throws InvalidConfigurationException {
548            super(a_conf, 2, a_returnType);
549        }
550
551        public String toString() {
552            return "GT(&1, &2)";
553        }
554
555        public String getName() {
556            return "GT";
557        }   
558
559        public float execute_float(ProgramChromosome c, int n, Object[] args) {
560            float f1 = c.execute_float(n, 0, args);
561            float f2 = c.execute_float(n, 1, args);
562
563            float ret = 1.0f;
564            if(f1 > f2) {
565                ret = 0.0f;
566            }
567
568            return ret;
569        }
570
571        public double execute_double(ProgramChromosome c, int n, Object[] args) {
572            double f1 = c.execute_double(n, 0, args);
573            double f2 = c.execute_double(n, 1, args);
574
575            double ret = 1;
576            if(f1 > f2)  {
577                ret = 0;
578            }
579            return ret;
580        }
581
582        public Object clone() {
583            try {
584                GT result = new GT(getGPConfiguration(), getReturnType());
585                return result;
586            }catch(Exception ex) {
587                throw new CloneException(ex);
588            }
589        }
590    }
591}
Note: See TracBrowser for help on using the repository browser.