source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalTraining2.java @ 18

Last change on this file since 18 was 17, checked in by atrautsch, 10 years ago

Aufräumarbeiten durchgeführt und jetzt auch eine Version die Funktioniert.
Debugausgaben habe ich erst mal nur auskommentiert.

File size: 23.3 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.io.PrintStream;
4import java.util.ArrayList;
5import java.util.HashMap;
6import java.util.HashSet;
7import java.util.Iterator;
8import java.util.Random;
9import java.util.Set;
10import java.util.logging.Level;
11
12import org.apache.commons.io.output.NullOutputStream;
13
14import de.ugoe.cs.cpdp.training.QuadTree;
15import de.ugoe.cs.util.console.Console;
16import weka.classifiers.AbstractClassifier;
17import weka.classifiers.Classifier;
18import weka.core.DenseInstance;
19import weka.core.EuclideanDistance;
20import weka.core.Instance;
21import weka.core.Instances;
22import weka.filters.Filter;
23import weka.filters.unsupervised.attribute.Remove;
24
25/**
26 * Trainer with reimplementation of WHERE clustering algorithm from:
27 * Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman,
28 * Forrest Shull, Burak Turhan, Thomas Zimmermann,
29 * "Local versus Global Lessons for Defect Prediction and Effort Estimation,"
30 * IEEE Transactions on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 
31 *
32 * With WekaLocalTraining2 we do the following:
33 * 1) Run the Fastmap algorithm on all training data, let it calculate the 2 most significant
34 *    dimensions and projections of each instance to these dimensions
35 * 2) With these 2 dimensions we span a QuadTree which gets recursively split on median(x) and median(y) values.
36 * 3) We cluster the QuadTree nodes together if they have similar density (50%)
37 * 4) We save the clusters and their training data
38 * 5) We only use clusters with > ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this cluster
39 * 6) We train a Weka classifier for each cluster with the clusters training data
40 * 7) We recalculate Fastmap distances for a single instance with the old pivots and then try to find a cluster containing the coords of the instance.
41 * 7.1.) If we can not find a cluster (due to coords outside of all clusters) we find the nearest cluster.
42 * 8) We classify the Instance with the classifier and traindata from the Cluster we found in 7.
43 */
44public class WekaLocalTraining2 extends WekaBaseTraining2 implements ITrainingStrategy {
45       
46        private final TraindatasetCluster classifier = new TraindatasetCluster();
47       
48        @Override
49        public Classifier getClassifier() {
50                return classifier;
51        }
52       
53        @Override
54        public void apply(Instances traindata) {
55                PrintStream errStr      = System.err;
56                System.setErr(new PrintStream(new NullOutputStream()));
57                try {
58                        classifier.buildClassifier(traindata);
59                } catch (Exception e) {
60                        throw new RuntimeException(e);
61                } finally {
62                        System.setErr(errStr);
63                }
64        }
65       
66       
67        public class TraindatasetCluster extends AbstractClassifier {
68               
69                private static final long serialVersionUID = 1L;
70               
71                /* classifier per cluster */
72                private HashMap<Integer, Classifier> cclassifier = new HashMap<Integer, Classifier>();
73               
74                /* instances per cluster */
75                private HashMap<Integer, Instances> ctraindata = new HashMap<Integer, Instances>();
76               
77                /* holds the instances and indices of the pivot objects of the Fastmap calculation in buildClassifier*/
78                private HashMap<Integer, Instance> cpivots = new HashMap<Integer, Instance>();
79               
80                /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension]*/
81                private int[][] cpivotindices = new int[2][2];
82
83                /* holds the sizes of the cluster multiple "boxes" per cluster */
84                private HashMap<Integer, ArrayList<Double[][]>> csize;
85               
86                private boolean show_biggest = true;
87               
88                private int CFOUND = 0;
89                private int CNOTFOUND = 0;
90               
91               
92                private Instance createInstance(Instances instances, Instance instance) {
93                        // attributes for feeding instance to classifier
94                        Set<String> attributeNames = new HashSet<>();
95                        for( int j=0; j<instances.numAttributes(); j++ ) {
96                                attributeNames.add(instances.attribute(j).name());
97                        }
98                       
99                        double[] values = new double[instances.numAttributes()];
100                        int index = 0;
101                        for( int j=0; j<instance.numAttributes(); j++ ) {
102                                if( attributeNames.contains(instance.attribute(j).name())) {
103                                        values[index] = instance.value(j);
104                                        index++;
105                                }
106                        }
107                       
108                        Instances tmp = new Instances(instances);
109                        tmp.clear();
110                        Instance instCopy = new DenseInstance(instance.weight(), values);
111                        instCopy.setDataset(tmp);
112                       
113                        return instCopy;
114                }
115               
116                /**
117                 * Because Fastmap saves only the image not the values of the attributes it used
118                 * we can not use the old data directly to classify single instances to clusters.
119                 *
120                 * To classify a single instance we do a new fastmap computation with only the instance and
121                 * the old pivot elements.
122                 *
123                 * After that we find the cluster with our fastmap result for x and y.
124                 */
125                @Override
126                public double classifyInstance(Instance instance) {
127                       
128                        double ret = 0;
129                        try {
130                                // classinstance gets passed to classifier
131                                Instances traindata = ctraindata.get(0);
132                                Instance classInstance = createInstance(traindata, instance);
133
134                                // this one keeps the class attribute
135                                Instances traindata2 = ctraindata.get(1); 
136                               
137                                // remove class attribute before clustering
138                                Remove filter = new Remove();
139                                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
140                                filter.setInputFormat(traindata);
141                                traindata = Filter.useFilter(traindata, filter);
142                                Instance clusterInstance = createInstance(traindata, instance);
143                               
144                                Fastmap FMAP = new Fastmap(2);
145                                EuclideanDistance dist = new EuclideanDistance(traindata);
146                               
147                               
148                                // we set our pivot indices [x=0,y=1][dimension]
149                                int[][] npivotindices = new int[2][2];
150                                npivotindices[0][0] = 1;
151                                npivotindices[1][0] = 2;
152                                npivotindices[0][1] = 3;
153                                npivotindices[1][1] = 4;
154                               
155                                // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify)
156                                // the instance we want to classify comes first after that the pivot elements in the order defined above
157                                double[][] distmat = new double[2*FMAP.target_dims+1][2*FMAP.target_dims+1];
158                                distmat[0][0] = 0;
159                                distmat[0][1] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][0]));
160                                distmat[0][2] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][0]));
161                                distmat[0][3] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][1]));
162                                distmat[0][4] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][1]));
163                               
164                                distmat[1][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), clusterInstance);
165                                distmat[1][1] = 0;
166                                distmat[1][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][0]));
167                                distmat[1][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[0][1]));
168                                distmat[1][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][1]));
169                               
170                                distmat[2][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), clusterInstance);
171                                distmat[2][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][0]));
172                                distmat[2][2] = 0;
173                                distmat[2][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][1]));
174                                distmat[2][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[1][1]));
175                               
176                                distmat[3][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), clusterInstance);
177                                distmat[3][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[0][0]));
178                                distmat[3][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][0]));
179                                distmat[3][3] = 0;
180                                distmat[3][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][1]));
181
182                                distmat[4][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), clusterInstance);
183                                distmat[4][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][0]));
184                                distmat[4][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[1][0]));
185                                distmat[4][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][1]));
186                                distmat[4][4] = 0;
187                               
188                               
189                                /* debug output: show biggest distance found within the new distance matrix
190                                double biggest = 0;
191                                for(int i=0; i < distmat.length; i++) {
192                                        for(int j=0; j < distmat[0].length; j++) {
193                                                if(biggest < distmat[i][j]) {
194                                                        biggest = distmat[i][j];
195                                                }
196                                        }
197                                }
198                                if(this.show_biggest) {
199                                        Console.traceln(Level.INFO, String.format(""+clusterInstance));
200                                        Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest));
201                                        this.show_biggest = false;
202                                }
203                                */
204
205                                FMAP.setDistmat(distmat);
206                                FMAP.setPivots(npivotindices);
207                                FMAP.calculate();
208                                double[][] x = FMAP.getX();
209                                double[] proj = x[0];
210
211                                // debug output: show the calculated distance matrix, our result vektor for the instance and the complete result matrix
212                                /*
213                                Console.traceln(Level.INFO, "distmat:");
214                            for(int i=0; i<distmat.length; i++){
215                                for(int j=0; j<distmat[0].length; j++){
216                                        Console.trace(Level.INFO, String.format("%20s", distmat[i][j]));
217                                }
218                                Console.traceln(Level.INFO, "");
219                            }
220                           
221                            Console.traceln(Level.INFO, "vector:");
222                            for(int i=0; i < proj.length; i++) {
223                                Console.trace(Level.INFO, String.format("%20s", proj[i]));
224                            }
225                            Console.traceln(Level.INFO, "");
226                           
227                                Console.traceln(Level.INFO, "resultmat:");
228                            for(int i=0; i<x.length; i++){
229                                for(int j=0; j<x[0].length; j++){
230                                        Console.trace(Level.INFO, String.format("%20s", x[i][j]));
231                                }
232                                Console.traceln(Level.INFO, "");
233                            }
234                            */
235                               
236                                // TODO: can we be in more cluster than one?
237                                // now we iterate over all clusters (well, boxes of sizes per cluster really) and save the number of the
238                                // cluster in which we are
239                                int cnumber;
240                                int found_cnumber = -1;
241                                Iterator<Integer> clusternumber = this.csize.keySet().iterator();
242                                while ( clusternumber.hasNext() && found_cnumber == -1) {
243                                        cnumber = clusternumber.next();
244                                       
245                                        // now iterate over the boxes of the cluster and hope we find one (cluster could have been removed)
246                                        // or we are too far away from any cluster
247                                        for ( int box=0; box < this.csize.get(cnumber).size(); box++ ) {
248                                                Double[][] current = this.csize.get(cnumber).get(box);
249                                               
250                                                if(proj[0] >= current[0][0] && proj[0] <= current[0][1] &&  // x
251                                                   proj[1] >= current[1][0] && proj[1] <= current[1][1]) {  // y
252                                                        found_cnumber = cnumber;
253                                                }
254                                        }
255                                }
256                               
257                                // we want to count how often we are really inside a cluster
258                                if ( found_cnumber == -1 ) {
259                                        CNOTFOUND += 1;
260                                }else {
261                                        CFOUND += 1;
262                                }
263
264                                // now it can happen that we dont find a cluster because we deleted it previously (too few instances)
265                                // or we get bigger distance measures from weka so that we are completely outside of our clusters.
266                                // in these cases we just find the nearest cluster to our instance and use it for classification.
267                                // to do that we use the EuclideanDistance again to compare our distance to all other Instances
268                                // then we take the cluster of the closest weka instance
269                                dist = new EuclideanDistance(traindata2);
270                                if( !this.ctraindata.containsKey(found_cnumber) ) {
271                                        double min_distance = 99999999;
272                                        clusternumber = ctraindata.keySet().iterator();
273                                        while ( clusternumber.hasNext() ) {
274                                                cnumber = clusternumber.next();
275                                                for(int i=0; i < ctraindata.get(cnumber).size(); i++) {
276                                                        if(dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance) {
277                                                                found_cnumber = cnumber;
278                                                                min_distance = dist.distance(instance, ctraindata.get(cnumber).get(i));
279                                                        }
280                                                }
281                                        }
282                                }
283                               
284                                // here we have the cluster where an instance has the minimum distance between itself the
285                                // instance we want to classify
286                                // if we still have not found a cluster we exit because something is really wrong
287                                if( found_cnumber == -1 ) {
288                                        Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster with full search!"));
289                                        throw new RuntimeException("cluster not found with full search");
290                                }
291                               
292                                // classify the passed instance with the cluster we found and its training data
293                                ret = cclassifier.get(found_cnumber).classifyInstance(classInstance);
294                               
295                        }catch( Exception e ) {
296                                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
297                                throw new RuntimeException(e);
298                        }
299                        return ret;
300                }
301               
302                @Override
303                public void buildClassifier(Instances traindata) throws Exception {
304                       
305                        //Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " + CNOTFOUND));
306                        this.show_biggest = true;
307                       
308                       
309                        // 1. copy traindata
310                        Instances train = new Instances(traindata);
311                        Instances train2 = new Instances(traindata);  // this one keeps the class attribute
312                       
313                        // 2. remove class attribute for clustering
314                        Remove filter = new Remove();
315                        filter.setAttributeIndices("" + (train.classIndex() + 1));
316                        filter.setInputFormat(train);
317                        train = Filter.useFilter(train, filter);
318                       
319                        // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1)
320                        double biggest = 0;
321                        EuclideanDistance dist = new EuclideanDistance(train);
322                        double[][] distmat = new double[train.size()][train.size()];
323                        for( int i=0; i < train.size(); i++ ) {
324                                for( int j=0; j < train.size(); j++ ) {
325                                        distmat[i][j] = dist.distance(train.get(i), train.get(j));
326                                        if( distmat[i][j] > biggest ) {
327                                                biggest = distmat[i][j];
328                                        }
329                                }
330                        }
331                        //Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest));
332                       
333                        // 4. run fastmap for 2 dimensions on the distance matrix
334                        Fastmap FMAP = new Fastmap(2);
335                        FMAP.setDistmat(distmat);
336                        FMAP.calculate();
337                       
338                        cpivotindices = FMAP.getPivots();
339                       
340                        double[][] X = FMAP.getX();
341                       
342                        // quadtree payload generation
343                        ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>();
344                   
345                        // die max und min brauchen wir für die größenangaben der sektoren
346                        double[] big = {0,0};
347                        double[] small = {9999999,99999999};
348                       
349                        // set quadtree payload values and get max and min x and y values for size
350                    for( int i=0; i<X.length; i++ ){
351                        if(X[i][0] >= big[0]) {
352                                big[0] = X[i][0];
353                        }
354                        if(X[i][1] >= big[1]) {
355                                big[1] = X[i][1];
356                        }
357                        if(X[i][0] <= small[0]) {
358                                small[0] = X[i][0];
359                        }
360                        if(X[i][1] <= small[1]) {
361                                small[1] = X[i][1];
362                        }
363                        QuadTreePayload<Instance> tmp = new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i));
364                        qtp.add(tmp);
365                    }
366                   
367                    Console.traceln(Level.INFO, String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")"));
368                   
369                    // 5. generate quadtree
370                    QuadTree TREE = new QuadTree(null, qtp);
371                    QuadTree.size = train.size();
372                    QuadTree.alpha = Math.sqrt(train.size());
373                    QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>();
374                    QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>();
375                   
376                    //Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size + " size, Alpha: "+ QuadTree.alpha+ ""));
377                   
378                    // set the size and then split the tree recursively at the median value for x, y
379                    TREE.setSize(new double[] {small[0], big[0]}, new double[] {small[1], big[1]});
380                   
381                    // recursive split und grid clustering eher static
382                    TREE.recursiveSplit(TREE);
383                   
384                    // generate list of nodes sorted by density (childs only)
385                    ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE));
386                   
387                    // recursive grid clustering (tree pruning), the values are stored in ccluster
388                    TREE.gridClustering(l);
389                   
390                    // wir iterieren durch die cluster und sammeln uns die instanzen daraus
391                    //ctraindata.clear();
392                    for( int i=0; i < QuadTree.ccluster.size(); i++ ) {
393                        ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i);
394                       
395                        // i is the clusternumber
396                        // we only allow clusters with Instances > ALPHA, other clusters are not considered!
397                        //if(current.size() > QuadTree.alpha) {
398                        if( current.size() > 4 ) {
399                                for( int j=0; j < current.size(); j++ ) {
400                                        if( !ctraindata.containsKey(i) ) {
401                                                ctraindata.put(i, new Instances(train2));
402                                                ctraindata.get(i).delete();
403                                        }
404                                        ctraindata.get(i).add(current.get(j).getInst());
405                                }
406                        }else{
407                                Console.traceln(Level.INFO, String.format("drop cluster, only: " + current.size() + " instances"));
408                        }
409                    }
410                       
411                        // here we keep things we need later on
412                        // QuadTree sizes for later use
413                        this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize);
414               
415                        // pivot elements
416                        //this.cpivots.clear();
417                        for( int i=0; i < FMAP.PA[0].length; i++ ) {
418                                this.cpivots.put(FMAP.PA[0][i], (Instance)train.get(FMAP.PA[0][i]).copy());
419                        }
420                        for( int j=0; j < FMAP.PA[0].length; j++ ) {
421                                this.cpivots.put(FMAP.PA[1][j], (Instance)train.get(FMAP.PA[1][j]).copy());
422                        }
423                       
424                       
425                        /* debug output
426                        int pnumber;
427                        Iterator<Integer> pivotnumber = cpivots.keySet().iterator();
428                        while ( pivotnumber.hasNext() ) {
429                                pnumber = pivotnumber.next();
430                                Console.traceln(Level.INFO, String.format("pivot: "+pnumber+ " inst: "+cpivots.get(pnumber)));
431                        }
432                        */
433                       
434                    // train one classifier per cluster, we get the clusternumber from the traindata
435                    int cnumber;
436                        Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
437                        //cclassifier.clear();
438                        while ( clusternumber.hasNext() ) {
439                                cnumber = clusternumber.next();
440                                cclassifier.put(cnumber,setupClassifier()); // das hier ist der eigentliche trainer
441                                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
442                                //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
443                                //Console.traceln(Level.INFO, String.format("" + ctraindata.get(cnumber).size() + " instances in cluster "+cnumber));
444                        }
445                       
446                        //Console.traceln(Level.INFO, String.format("num clusters: "+cclassifier.size()));
447                }
448        }
449       
450
451        /**
452         * Payload for the QuadTree.
453         * x and y are the calculated Fastmap values.
454         * T is a weka instance.
455         */
456        public class QuadTreePayload<T> {
457
458                public double x;
459                public double y;
460                private T inst;
461               
462                public QuadTreePayload(double x, double y, T value) {
463                        this.x = x;
464                        this.y = y;
465                        this.inst = value;
466                }
467               
468                public T getInst() {
469                        return this.inst;
470                }
471        }
472       
473       
474        /**
475         * Fastmap implementation
476         *
477         * Faloutsos, C., & Lin, K. I. (1995).
478         * FastMap: A fast algorithm for indexing, data-mining and visualization of traditional and multimedia datasets
479         * (Vol. 24, No. 2, pp. 163-174). ACM.
480         */
481        public class Fastmap {
482               
483                /*N x k Array, at the end, the i-th row will be the image of the i-th object*/
484                private double[][] X;
485               
486                /*2 x k pivot Array one pair per recursive call*/
487                private int[][] PA;
488               
489                /*Objects we got (distance matrix)*/
490                private double[][] O;
491               
492                /*column of X currently updated (also the dimension)*/
493                private int col = 0;
494               
495                /*number of dimensions we want*/
496                private int target_dims = 0;
497               
498                // if we already have the pivot elements
499                private boolean pivot_set = false;
500               
501
502                public Fastmap(int k) {
503                        this.target_dims = k;
504                }
505               
506                /**
507                 * Sets the distance matrix
508                 * and params that depend on this
509                 * @param O
510                 */
511                public void setDistmat(double[][] O) {
512                        this.O = O;
513                        int N = O.length;
514                        this.X = new double[N][this.target_dims];
515                        this.PA = new int[2][this.target_dims];
516                }
517               
518                /**
519                 * Set pivot elements, we need that to classify instances
520                 * after the calculation is complete (because we then want to reuse
521                 * only the pivot elements).
522                 *
523                 * @param pi
524                 */
525                public void setPivots(int[][] pi) {
526                        this.pivot_set = true;
527                        this.PA = pi;
528                }
529               
530                /**
531                 * Return the pivot elements that were chosen during the calculation
532                 *
533                 * @return
534                 */
535                public int[][] getPivots() {
536                        return this.PA;
537                }
538               
539                /**
540                 * The distance function for euclidean distance
541                 *
542                 * Acts according to equation 4 of the fastmap paper
543                 * 
544                 * @param x x index of x image (if k==0 x object)
545                 * @param y y index of y image (if k==0 y object)
546                 * @param kdimensionality
547                 * @return distance
548                 */
549                private double dist(int x, int y, int k) {
550                       
551                        // basis is object distance, we get this from our distance matrix
552                        double tmp = this.O[x][y] * this.O[x][y];
553                       
554                        // decrease by projections
555                        for( int i=0; i < k; i++ ) {
556                                double tmp2 = (this.X[x][i] - this.X[y][i]);
557                                tmp -= tmp2 * tmp2;
558                        }
559                       
560                        return Math.abs(tmp);
561                }
562
563                /**
564                 * Find the object farthest from the given index
565                 * This method is a helper Method for findDistandObjects
566                 *
567                 * @param index of the object
568                 * @return index of the farthest object from the given index
569                 */
570                private int findFarthest(int index) {
571                        double furthest = -1000000;
572                        int ret = 0;
573                       
574                        for( int i=0; i < O.length; i++ ) {
575                                double dist = this.dist(i, index, this.col);
576                                if( i != index && dist > furthest ) {
577                                        furthest = dist;
578                                        ret = i;
579                                }
580                        }
581                        return ret;
582                }
583               
584                /**
585                 * Finds the pivot objects
586                 *
587                 * This method is basically algorithm 1 of the fastmap paper.
588                 *
589                 * @return 2 indexes of the choosen pivot objects
590                 */
591                private int[] findDistantObjects() {
592                        // 1. choose object randomly
593                        Random r = new Random();
594                        int obj = r.nextInt(this.O.length);
595                       
596                        // 2. find farthest object from randomly chosen object
597                        int idx1 = this.findFarthest(obj);
598                       
599                        // 3. find farthest object from previously farthest object
600                        int idx2 = this.findFarthest(idx1);
601
602                        return new int[] {idx1, idx2};
603                }
604       
605                /**
606                 * Calculates the new k-vector values (projections)
607                 *
608                 * This is basically algorithm 2 of the fastmap paper.
609                 * We just added the possibility to pre-set the pivot elements because
610                 * we need to classify single instances after the computation is already done.
611                 *
612                 * @param dims dimensionality
613                 */
614                public void calculate() {
615                       
616                        for( int k=0; k < this.target_dims; k++ ) {
617                                // 2) choose pivot objects
618                                if ( !this.pivot_set ) {
619                                        int[] pivots = this.findDistantObjects();
620               
621                                        // 3) record ids of pivot objects
622                                        this.PA[0][this.col] = pivots[0];
623                                        this.PA[1][this.col] = pivots[1];
624                                }
625                               
626                                // 4) inter object distances are zero (this.X is initialized with 0 so we just continue)
627                                if( this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0 ) {
628                                        continue;
629                                }
630                               
631                                // 5) project the objects on the line between the pivots
632                                double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col);
633                                for( int i=0; i < this.O.length; i++ ) {
634                                       
635                                        double dix = this.dist(i, this.PA[0][this.col], this.col);
636                                        double diy = this.dist(i, this.PA[1][this.col], this.col);
637                                       
638                                        double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy));
639                                       
640                                        // save the projection
641                                        this.X[i][this.col] = tmp;
642                                }
643                               
644                                this.col += 1;
645                        }
646                }
647               
648                /**
649                 * returns the result matrix of the projections
650                 *
651                 * @return calculated result
652                 */
653                public double[][] getX() {
654                        return this.X;
655                }
656        }
657}
Note: See TracBrowser for help on using the repository browser.