source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/GPTraining.java @ 101

Last change on this file since 101 was 93, checked in by atrautsch, 9 years ago

init gptraining

File size: 13.5 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import org.apache.commons.collections4.list.SetUniqueList;
4
5import weka.classifiers.AbstractClassifier;
6import weka.classifiers.Classifier;
7import weka.core.Instance;
8import weka.core.Instances;
9import org.apache.commons.lang3.ArrayUtils;
10
11import org.jgap.InvalidConfigurationException;
12import org.jgap.gp.CommandGene;
13import org.jgap.gp.GPProblem;
14
15import org.jgap.gp.function.Add;
16import org.jgap.gp.function.Multiply;
17import org.jgap.gp.function.Log;
18import org.jgap.gp.function.Subtract;
19import org.jgap.gp.function.Divide;
20import org.jgap.gp.function.Sine;
21import org.jgap.gp.function.Cosine;
22import org.jgap.gp.function.Max;
23import org.jgap.gp.function.Exp;
24
25import org.jgap.gp.impl.DeltaGPFitnessEvaluator;
26import org.jgap.gp.impl.GPConfiguration;
27import org.jgap.gp.impl.GPGenotype;
28import org.jgap.gp.impl.TournamentSelector;
29import org.jgap.gp.terminal.Terminal;
30import org.jgap.gp.GPFitnessFunction;
31import org.jgap.gp.IGPProgram;
32import org.jgap.gp.terminal.Variable;
33import org.jgap.gp.MathCommand;
34import org.jgap.util.ICloneable;
35
36import org.jgap.gp.impl.ProgramChromosome;
37import org.jgap.util.CloneException;
38
39/**
40 * Genetic Programming Trainer
41 *
42 */
43public class GPTraining implements ISetWiseTrainingStrategy, IWekaCompatibleTrainer  {
44   
45    private final GPClassifier classifier = new GPClassifier();
46   
47    private int populationSize = 1000;
48    private int initMinDepth = 2;
49    private int initMaxDepth = 6;
50    private int tournamentSize = 7;
51   
52    @Override
53    public void setParameter(String parameters) {
54        System.out.println("setParameters");
55    }
56
57    @Override
58    public void apply(SetUniqueList<Instances> traindataSet) {
59        System.out.println("apply");
60        for (Instances traindata : traindataSet) {
61            try {
62                classifier.buildClassifier(traindata);
63            }catch(Exception e) {
64                throw new RuntimeException(e);
65            }
66        }
67    }
68
69    @Override
70    public String getName() {
71        System.out.println("getName");
72        return "GPTraining";
73    }
74
75    @Override
76    public Classifier getClassifier() {
77        System.out.println("getClassifier");
78        return this.classifier;
79    }
80   
81    public class InstanceData {
82        private double[][] instances_x;
83        private boolean[] instances_y;
84       
85        public InstanceData(Instances instances) {
86            this.instances_x = new double[instances.numInstances()][instances.numAttributes()-1];
87
88            Instance current;
89            for(int i=0; i < this.instances_x.length; i++) {
90                current = instances.get(i);
91                for(int j=0; j < this.instances_x[0].length; j++) {
92                    this.instances_x[i][j] = current.value(j);
93                }
94               
95                this.instances_y[i] = current.stringValue(instances.classIndex()).equals("Y");
96            }
97        }
98       
99        public double[][] getX() {
100            return instances_x;
101        }
102        public boolean[] getY() {
103            return instances_y;
104        }
105    }
106   
107    public class GPClassifier extends AbstractClassifier {
108
109        private static final long serialVersionUID = 3708714057579101522L;
110
111        private int populationSize = 1000;
112        private int initMinDepth = 2;
113        private int initMaxDepth = 6;
114        private int tournamentSize = 7;
115        private int maxGenerations = 50;
116       
117        private GPGenotype gp;
118        private GPProblem problem;
119       
120        public void configure(int populationSize, int initMinDepth, int initMaxDepth, int tournamentSize, int maxGenerations) {
121            this.populationSize = populationSize;
122            this.initMinDepth = initMinDepth;
123            this.initMaxDepth = initMaxDepth;
124            this.tournamentSize = tournamentSize;
125            this.maxGenerations = maxGenerations;
126        }
127       
128        @Override
129        public void buildClassifier(Instances instances) throws Exception {
130            // load instances into double[][] and boolean[]
131            InstanceData train = new InstanceData(instances);
132            this.problem = new CrossPareGP(train.getX(), train.getY(), this.populationSize, this.initMinDepth, this.initMaxDepth, this.tournamentSize);
133           
134            this.gp = problem.create();
135            this.gp.evolve(this.maxGenerations);
136        }
137       
138        @Override
139        public double classifyInstance(Instance instance) {
140            Variable[] vars = ((CrossPareGP)this.problem).getVariables();
141           
142            double[][] x = new double[1][instance.numAttributes()-1];
143            boolean[] y = new boolean[1];
144           
145            for(int i = 0; i < instance.numAttributes()-1; i++) {
146                x[0][i] = instance.value(i);
147            }
148            y[0] = instance.stringValue(instance.classIndex()).equals("Y");
149           
150            CrossPareFitness test = new CrossPareFitness(vars, x, y);
151            IGPProgram fitest = gp.getAllTimeBest();
152           
153            double sfitness = test.evaluate(fitest);
154           
155            // korrekt sind wir wenn wir geringe fitness haben?
156            if(sfitness < 0.5) {
157                return 1.0;
158            }
159            return 0;
160           
161        }
162
163        /**
164         * GPProblem implementation
165         */
166        class CrossPareGP extends GPProblem {
167           
168            private static final long serialVersionUID = 7526472295622776147L;
169
170            private double[][] instances;
171            private boolean[] output;
172
173            private Variable[] x;
174
175            public CrossPareGP(double[][] instances, boolean[] output, int populationSize, int minInitDept, int maxInitDepth, int tournamentSize) throws InvalidConfigurationException {
176                super(new GPConfiguration());
177
178                this.instances = instances;
179                this.output = output;
180
181                GPConfiguration config = this.getGPConfiguration();
182
183                this.x = new Variable[this.instances[0].length];
184
185                for(int j=0; j < this.x.length; j++) {
186                    this.x[j] = Variable.create(config, "X"+j, CommandGene.DoubleClass);   
187                }
188
189                config.setGPFitnessEvaluator(new DeltaGPFitnessEvaluator()); // smaller fitness is better
190                //config.setGPFitnessEvaluator(new DefaultGPFitnessEvaluator()); // bigger fitness is better
191
192                // from paper: 2-6
193                config.setMinInitDepth(minInitDept);
194                config.setMaxInitDepth(maxInitDepth);
195
196                // missing from paper
197                // config.setMaxDepth(20);
198
199                config.setCrossoverProb((float)0.60);
200                config.setReproductionProb((float)0.10);
201                config.setMutationProb((float)0.30);
202
203                config.setSelectionMethod(new TournamentSelector(tournamentSize));
204
205                // from paper 1000
206                config.setPopulationSize(populationSize);
207
208                // BranchTypingCross
209                config.setMaxCrossoverDepth(4);
210                config.setFitnessFunction(new CrossPareFitness(this.x, this.instances, this.output));
211                config.setStrictProgramCreation(true);
212            }
213
214            // used for running the fitness function again for testing
215            public Variable[] getVariables() {
216                return this.x;
217            }
218
219
220            public GPGenotype create() throws InvalidConfigurationException {
221                GPConfiguration config = this.getGPConfiguration();
222
223                // return type
224                Class[] types = {CommandGene.DoubleClass};
225
226                // Arguments of result-producing chromosome: none
227                Class[][] argTypes = { {} };
228
229                // variables + functions
230                CommandGene[] vars = new CommandGene[this.instances[0].length];
231                for(int j=0; j < this.instances[0].length; j++) {
232                    vars[j] = this.x[j];
233                }
234                CommandGene[] funcs = {
235                    new Add(config, CommandGene.DoubleClass),
236                    new Subtract(config, CommandGene.DoubleClass),
237                    new Multiply(config, CommandGene.DoubleClass),
238                    new Divide(config, CommandGene.DoubleClass),
239                    new Sine(config, CommandGene.DoubleClass),
240                    new Cosine(config, CommandGene.DoubleClass),
241                    new Exp(config, CommandGene.DoubleClass),
242                    new Log(config, CommandGene.DoubleClass),
243                    new GT(config, CommandGene.DoubleClass),
244                    new Max(config, CommandGene.DoubleClass),
245                    new Terminal(config, CommandGene.DoubleClass, -100.0, 100.0, true), // min, max, whole numbers
246                };
247
248                CommandGene[] comb = (CommandGene[])ArrayUtils.addAll(vars, funcs);
249                CommandGene[][] nodeSets = {
250                    comb,
251                };
252
253                GPGenotype result = GPGenotype.randomInitialGenotype(config, types, argTypes, nodeSets, 20, true); // 20 = maxNodes, true = verbose output
254
255                return result;
256            }
257        }
258       
259        /**
260         * Fitness function
261         */
262        class CrossPareFitness extends GPFitnessFunction {
263           
264            private static final long serialVersionUID = 75234832484387L;
265
266            private Variable[] x;
267
268            private double[][] instances;
269            private boolean[] output;
270
271            private double error_type2_weight = 1.0;
272
273            // needed in evaluate
274            private Object[] NO_ARGS = new Object[0];
275
276            private double sfitness = 0.0f;
277            private int error_type1 = 0;
278            private int error_type2 = 0;
279
280            public CrossPareFitness(Variable[] x, double[][] instances, boolean[] output) {
281                this.x = x;
282                this.instances = instances;
283                this.output = output;
284            }
285
286            public int getErrorType1() {
287                return this.error_type1;
288            }
289
290            public int getErrorType2() {
291                return this.error_type2;
292            }
293
294            public double getSecondFitness() {
295                return this.sfitness;
296            }
297
298            public int getNumInstances() {
299                return this.instances.length;
300            }
301
302            @Override
303            protected double evaluate(final IGPProgram program) {
304                double pfitness = 0.0f;
305                this.sfitness = 0.0f;
306                double value = 0.0f;
307
308                // count classification errors
309                this.error_type1 = 0;
310                this.error_type2 = 0;
311
312                for(int i=0; i < this.instances.length; i++) {
313
314                    // requires that we have a variable for each column of our dataset (attribute of instance)
315                    for(int j=0; j < this.x.length; j++) {
316                        this.x[j].set(this.instances[i][j]);
317                    }
318
319                    // value gives us a double, if > 0.5 we set this instance as faulty
320                    value = program.execute_double(0, NO_ARGS);
321
322                    if(value < 0.5) {
323                        if(this.output[i] != true) {
324                            this.error_type1 += 1;
325                        }
326                    }else {
327                        if(this.output[i] == true) {
328                            this.error_type2 += 1;
329                        }
330                    }
331                }
332
333                // now calc pfitness
334                pfitness = (this.error_type1 + this.error_type2_weight * this.error_type2) / this.instances.length;
335
336                //System.out.println("pfitness: " + pfitness);
337
338                // number of nodes in the programm, if lower then 10 we assign sFitness of 10
339                if(program.getChromosome(0).getSize(0) < 10) {
340                    this.sfitness = 10.0f;
341                    //System.out.println("wenige nodes: "+program.getChromosome(0).getSize(0));
342                    //System.out.println(program.toStringNorm(0));
343                }
344
345                // sfitness counts the number of nodes in the tree, if it is lower than 10 fitness is increased by 10
346
347                return pfitness;
348            }
349        }
350    }
351   
352   
353    /**
354    * Custom GT implementation from the paper
355    */
356    public class GT extends MathCommand implements ICloneable {
357       
358        private static final long serialVersionUID = 113454184817L;
359
360        public GT(final GPConfiguration a_conf, java.lang.Class a_returnType) throws InvalidConfigurationException {
361            super(a_conf, 2, a_returnType);
362        }
363
364        public String toString() {
365            return "GT(&1, &2)";
366        }
367
368        public String getName() {
369            return "GT";
370        }   
371
372        public float execute_float(ProgramChromosome c, int n, Object[] args) {
373            float f1 = c.execute_float(n, 0, args);
374            float f2 = c.execute_float(n, 1, args);
375
376            float ret = 1.0f;
377            if(f1 > f2) {
378                ret = 0.0f;
379            }
380
381            return ret;
382        }
383
384        public double execute_double(ProgramChromosome c, int n, Object[] args) {
385            double f1 = c.execute_double(n, 0, args);
386            double f2 = c.execute_double(n, 1, args);
387
388            double ret = 1;
389            if(f1 > f2)  {
390                ret = 0;
391            }
392            return ret;
393        }
394
395        public Object clone() {
396            try {
397                GT result = new GT(getGPConfiguration(), getReturnType());
398                return result;
399            }catch(Exception ex) {
400                throw new CloneException(ex);
401            }
402        }
403    }
404}
Note: See TracBrowser for help on using the repository browser.