source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalFQTraining.java

Last change on this file was 136, checked in by sherbold, 8 years ago
  • more code documentation
File size: 32.5 KB
Line 
1// Copyright 2015 Georg-August-Universität Göttingen, Germany
2//
3//   Licensed under the Apache License, Version 2.0 (the "License");
4//   you may not use this file except in compliance with the License.
5//   You may obtain a copy of the License at
6//
7//       http://www.apache.org/licenses/LICENSE-2.0
8//
9//   Unless required by applicable law or agreed to in writing, software
10//   distributed under the License is distributed on an "AS IS" BASIS,
11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12//   See the License for the specific language governing permissions and
13//   limitations under the License.
14
15package de.ugoe.cs.cpdp.training;
16
17import java.util.ArrayList;
18import java.util.HashMap;
19import java.util.HashSet;
20import java.util.Iterator;
21import java.util.Random;
22import java.util.Set;
23import java.util.logging.Level;
24
25import de.ugoe.cs.cpdp.training.QuadTree;
26import de.ugoe.cs.util.console.Console;
27import weka.classifiers.AbstractClassifier;
28import weka.classifiers.Classifier;
29import weka.core.DenseInstance;
30import weka.core.EuclideanDistance;
31import weka.core.Instance;
32import weka.core.Instances;
33import weka.filters.Filter;
34import weka.filters.unsupervised.attribute.Remove;
35
36/**
37 * <p>
38 * Trainer with reimplementation of WHERE clustering algorithm from: Tim Menzies, Andrew Butcher,
39 * David Cok, Andrian Marcus, Lucas Layman, Forrest Shull, Burak Turhan, Thomas Zimmermann,
40 * "Local versus Global Lessons for Defect Prediction and Effort Estimation," IEEE Transactions on
41 * Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013
42 * </p>
43 * <p>
44 * With WekaLocalFQTraining we do the following:
45 * <ol>
46 * <li>Run the Fastmap algorithm on all training data, let it calculate the 2 most significant
47 * dimensions and projections of each instance to these dimensions</li>
48 * <li>With these 2 dimensions we span a QuadTree which gets recursively split on median(x) and
49 * median(y) values.</li>
50 * <li>We cluster the QuadTree nodes together if they have similar density (50%)</li>
51 * <li>We save the clusters and their training data</li>
52 * <li>We only use clusters with > ALPHA instances (currently Math.sqrt(SIZE)), the rest is
53 * discarded with the training data of this cluster</li>
54 * <li>We train a Weka classifier for each cluster with the clusters training data</li>
55 * <li>We recalculate Fastmap distances for a single instance with the old pivots and then try to
56 * find a cluster containing the coords of the instance. If we can not find a cluster (due to coords
57 * outside of all clusters) we find the nearest cluster.</li>
58 * <li>We classify the Instance with the classifier and traindata from the Cluster we found in 7.
59 * </li>
60 * </p>
61 */
62public class WekaLocalFQTraining extends WekaBaseTraining implements ITrainingStrategy {
63
64    /**
65     * the classifier
66     */
67    private final TraindatasetCluster classifier = new TraindatasetCluster();
68
69    /*
70     * (non-Javadoc)
71     *
72     * @see de.ugoe.cs.cpdp.training.WekaBaseTraining#getClassifier()
73     */
74    @Override
75    public Classifier getClassifier() {
76        return classifier;
77    }
78
79    /*
80     * (non-Javadoc)
81     *
82     * @see de.ugoe.cs.cpdp.training.ITrainingStrategy#apply(weka.core.Instances)
83     */
84    @Override
85    public void apply(Instances traindata) {
86        try {
87            classifier.buildClassifier(traindata);
88        }
89        catch (Exception e) {
90            throw new RuntimeException(e);
91        }
92    }
93
94    /**
95     * <p>
96     * Weka classifier for the local model with WHERE clustering
97     * </p>
98     *
99     * @author Alexander Trautsch
100     */
101    public class TraindatasetCluster extends AbstractClassifier {
102
103        /**
104         * default serialization ID
105         */
106        private static final long serialVersionUID = 1L;
107
108        /**
109         * classifiers for each cluster
110         */
111        private HashMap<Integer, Classifier> cclassifier;
112
113        /**
114         * training data for each cluster
115         */
116        private HashMap<Integer, Instances> ctraindata;
117
118        /**
119         * holds the instances and indices of the pivot objects of the Fastmap calculation in
120         * buildClassifier
121         */
122        private HashMap<Integer, Instance> cpivots;
123
124        /**
125         * holds the indices of the pivot objects for x,y and the dimension [x,y][dimension]
126         */
127        private int[][] cpivotindices;
128
129        /**
130         * holds the sizes of the cluster multiple "boxes" per cluster
131         */
132        private HashMap<Integer, ArrayList<Double[][]>> csize;
133
134        /**
135         * debug variable
136         */
137        @SuppressWarnings("unused")
138        private boolean show_biggest = true;
139
140        /**
141         * debug variable
142         */
143        @SuppressWarnings("unused")
144        private int CFOUND = 0;
145
146        /**
147         * debug variable
148         */
149        @SuppressWarnings("unused")
150        private int CNOTFOUND = 0;
151
152        /**
153         * <p>
154         * copies an instance such that is is compatible with the local model
155         * </p>
156         *
157         * @param instances
158         *            instance format
159         * @param instance
160         *            instance that is copied
161         * @return
162         */
163        private Instance createInstance(Instances instances, Instance instance) {
164            // attributes for feeding instance to classifier
165            Set<String> attributeNames = new HashSet<>();
166            for (int j = 0; j < instances.numAttributes(); j++) {
167                attributeNames.add(instances.attribute(j).name());
168            }
169
170            double[] values = new double[instances.numAttributes()];
171            int index = 0;
172            for (int j = 0; j < instance.numAttributes(); j++) {
173                if (attributeNames.contains(instance.attribute(j).name())) {
174                    values[index] = instance.value(j);
175                    index++;
176                }
177            }
178
179            Instances tmp = new Instances(instances);
180            tmp.clear();
181            Instance instCopy = new DenseInstance(instance.weight(), values);
182            instCopy.setDataset(tmp);
183
184            return instCopy;
185        }
186
187        /**
188         * <p>
189         * Because Fastmap saves only the image not the values of the attributes it used we can not
190         * use the old data directly to classify single instances to clusters.
191         * </p>
192         * <p>
193         * To classify a single instance we do a new Fastmap computation with only the instance and
194         * the old pivot elements.
195         * </p>
196         * </p>
197         * After that we find the cluster with our Fastmap result for x and y.
198         * </p>
199         *
200         * @param instance
201         *            instance that is classified
202         * @see weka.classifiers.AbstractClassifier#classifyInstance(weka.core.Instance)
203         */
204        @Override
205        public double classifyInstance(Instance instance) {
206
207            double ret = 0;
208            try {
209                // classinstance gets passed to classifier
210                Instances traindata = ctraindata.get(0);
211                Instance classInstance = createInstance(traindata, instance);
212
213                // this one keeps the class attribute
214                Instances traindata2 = ctraindata.get(1);
215
216                // remove class attribute before clustering
217                Remove filter = new Remove();
218                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
219                filter.setInputFormat(traindata);
220                traindata = Filter.useFilter(traindata, filter);
221                Instance clusterInstance = createInstance(traindata, instance);
222
223                Fastmap FMAP = new Fastmap(2);
224                EuclideanDistance dist = new EuclideanDistance(traindata);
225
226                // we set our pivot indices [x=0,y=1][dimension]
227                int[][] npivotindices = new int[2][2];
228                npivotindices[0][0] = 1;
229                npivotindices[1][0] = 2;
230                npivotindices[0][1] = 3;
231                npivotindices[1][1] = 4;
232
233                // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify)
234                // the instance we want to classify comes first after that the pivot elements in the
235                // order defined above
236                double[][] distmat = new double[2 * FMAP.target_dims + 1][2 * FMAP.target_dims + 1];
237                distmat[0][0] = 0;
238                distmat[0][1] = dist.distance(clusterInstance,
239                                              this.cpivots.get((Integer) this.cpivotindices[0][0]));
240                distmat[0][2] = dist.distance(clusterInstance,
241                                              this.cpivots.get((Integer) this.cpivotindices[1][0]));
242                distmat[0][3] = dist.distance(clusterInstance,
243                                              this.cpivots.get((Integer) this.cpivotindices[0][1]));
244                distmat[0][4] = dist.distance(clusterInstance,
245                                              this.cpivots.get((Integer) this.cpivotindices[1][1]));
246
247                distmat[1][0] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
248                                              clusterInstance);
249                distmat[1][1] = 0;
250                distmat[1][2] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
251                                              this.cpivots.get((Integer) this.cpivotindices[1][0]));
252                distmat[1][3] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
253                                              this.cpivots.get((Integer) this.cpivotindices[0][1]));
254                distmat[1][4] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][0]),
255                                              this.cpivots.get((Integer) this.cpivotindices[1][1]));
256
257                distmat[2][0] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
258                                              clusterInstance);
259                distmat[2][1] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
260                                              this.cpivots.get((Integer) this.cpivotindices[0][0]));
261                distmat[2][2] = 0;
262                distmat[2][3] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
263                                              this.cpivots.get((Integer) this.cpivotindices[0][1]));
264                distmat[2][4] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][0]),
265                                              this.cpivots.get((Integer) this.cpivotindices[1][1]));
266
267                distmat[3][0] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
268                                              clusterInstance);
269                distmat[3][1] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
270                                              this.cpivots.get((Integer) this.cpivotindices[0][0]));
271                distmat[3][2] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
272                                              this.cpivots.get((Integer) this.cpivotindices[1][0]));
273                distmat[3][3] = 0;
274                distmat[3][4] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[0][1]),
275                                              this.cpivots.get((Integer) this.cpivotindices[1][1]));
276
277                distmat[4][0] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
278                                              clusterInstance);
279                distmat[4][1] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
280                                              this.cpivots.get((Integer) this.cpivotindices[0][0]));
281                distmat[4][2] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
282                                              this.cpivots.get((Integer) this.cpivotindices[1][0]));
283                distmat[4][3] = dist.distance(this.cpivots.get((Integer) this.cpivotindices[1][1]),
284                                              this.cpivots.get((Integer) this.cpivotindices[0][1]));
285                distmat[4][4] = 0;
286
287                /*
288                 * debug output: show biggest distance found within the new distance matrix double
289                 * biggest = 0; for(int i=0; i < distmat.length; i++) { for(int j=0; j <
290                 * distmat[0].length; j++) { if(biggest < distmat[i][j]) { biggest = distmat[i][j];
291                 * } } } if(this.show_biggest) { Console.traceln(Level.INFO,
292                 * String.format(""+clusterInstance)); Console.traceln(Level.INFO, String.format(
293                 * "biggest distances: "+ biggest)); this.show_biggest = false; }
294                 */
295
296                FMAP.setDistmat(distmat);
297                FMAP.setPivots(npivotindices);
298                FMAP.calculate();
299                double[][] x = FMAP.getX();
300                double[] proj = x[0];
301
302                // debug output: show the calculated distance matrix, our result vektor for the
303                // instance and the complete result matrix
304                /*
305                 * Console.traceln(Level.INFO, "distmat:"); for(int i=0; i<distmat.length; i++){
306                 * for(int j=0; j<distmat[0].length; j++){ Console.trace(Level.INFO,
307                 * String.format("%20s", distmat[i][j])); } Console.traceln(Level.INFO, ""); }
308                 *
309                 * Console.traceln(Level.INFO, "vector:"); for(int i=0; i < proj.length; i++) {
310                 * Console.trace(Level.INFO, String.format("%20s", proj[i])); }
311                 * Console.traceln(Level.INFO, "");
312                 *
313                 * Console.traceln(Level.INFO, "resultmat:"); for(int i=0; i<x.length; i++){ for(int
314                 * j=0; j<x[0].length; j++){ Console.trace(Level.INFO, String.format("%20s",
315                 * x[i][j])); } Console.traceln(Level.INFO, ""); }
316                 */
317
318                // now we iterate over all clusters (well, boxes of sizes per cluster really) and
319                // save the number of the
320                // cluster in which we are
321                int cnumber;
322                int found_cnumber = -1;
323                Iterator<Integer> clusternumber = this.csize.keySet().iterator();
324                while (clusternumber.hasNext() && found_cnumber == -1) {
325                    cnumber = clusternumber.next();
326
327                    // now iterate over the boxes of the cluster and hope we find one (cluster could
328                    // have been removed)
329                    // or we are too far away from any cluster because of the fastmap calculation
330                    // with the initial pivot objects
331                    for (int box = 0; box < this.csize.get(cnumber).size(); box++) {
332                        Double[][] current = this.csize.get(cnumber).get(box);
333
334                        if (proj[0] >= current[0][0] && proj[0] <= current[0][1] && // x
335                            proj[1] >= current[1][0] && proj[1] <= current[1][1])
336                        { // y
337                            found_cnumber = cnumber;
338                        }
339                    }
340                }
341
342                // we want to count how often we are really inside a cluster
343                // if ( found_cnumber == -1 ) {
344                // CNOTFOUND += 1;
345                // }else {
346                // CFOUND += 1;
347                // }
348
349                // now it can happen that we do not find a cluster because we deleted it previously
350                // (too few instances)
351                // or we get bigger distance measures from weka so that we are completely outside of
352                // our clusters.
353                // in these cases we just find the nearest cluster to our instance and use it for
354                // classification.
355                // to do that we use the EuclideanDistance again to compare our distance to all
356                // other Instances
357                // then we take the cluster of the closest weka instance
358                dist = new EuclideanDistance(traindata2);
359                if (!this.ctraindata.containsKey(found_cnumber)) {
360                    double min_distance = Double.MAX_VALUE;
361                    clusternumber = ctraindata.keySet().iterator();
362                    while (clusternumber.hasNext()) {
363                        cnumber = clusternumber.next();
364                        for (int i = 0; i < ctraindata.get(cnumber).size(); i++) {
365                            if (dist.distance(instance,
366                                              ctraindata.get(cnumber).get(i)) <= min_distance)
367                            {
368                                found_cnumber = cnumber;
369                                min_distance =
370                                    dist.distance(instance, ctraindata.get(cnumber).get(i));
371                            }
372                        }
373                    }
374                }
375
376                // here we have the cluster where an instance has the minimum distance between
377                // itself and the
378                // instance we want to classify
379                // if we still have not found a cluster we exit because something is really wrong
380                if (found_cnumber == -1) {
381                    Console.traceln(Level.INFO, String
382                        .format("ERROR matching instance to cluster with full search!"));
383                    throw new RuntimeException("cluster not found with full search");
384                }
385
386                // classify the passed instance with the cluster we found and its training data
387                ret = cclassifier.get(found_cnumber).classifyInstance(classInstance);
388
389            }
390            catch (Exception e) {
391                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
392                throw new RuntimeException(e);
393            }
394            return ret;
395        }
396
397        /*
398         * (non-Javadoc)
399         *
400         * @see weka.classifiers.Classifier#buildClassifier(weka.core.Instances)
401         */
402        @Override
403        public void buildClassifier(Instances traindata) throws Exception {
404
405            // Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " +
406            // CNOTFOUND));
407            this.show_biggest = true;
408
409            cclassifier = new HashMap<Integer, Classifier>();
410            ctraindata = new HashMap<Integer, Instances>();
411            cpivots = new HashMap<Integer, Instance>();
412            cpivotindices = new int[2][2];
413
414            // 1. copy traindata
415            Instances train = new Instances(traindata);
416            Instances train2 = new Instances(traindata); // this one keeps the class attribute
417
418            // 2. remove class attribute for clustering
419            Remove filter = new Remove();
420            filter.setAttributeIndices("" + (train.classIndex() + 1));
421            filter.setInputFormat(train);
422            train = Filter.useFilter(train, filter);
423
424            // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1)
425            double biggest = 0;
426            EuclideanDistance dist = new EuclideanDistance(train);
427            double[][] distmat = new double[train.size()][train.size()];
428            for (int i = 0; i < train.size(); i++) {
429                for (int j = 0; j < train.size(); j++) {
430                    distmat[i][j] = dist.distance(train.get(i), train.get(j));
431                    if (distmat[i][j] > biggest) {
432                        biggest = distmat[i][j];
433                    }
434                }
435            }
436            // Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest));
437
438            // 4. run fastmap for 2 dimensions on the distance matrix
439            Fastmap FMAP = new Fastmap(2);
440            FMAP.setDistmat(distmat);
441            FMAP.calculate();
442
443            cpivotindices = FMAP.getPivots();
444
445            double[][] X = FMAP.getX();
446
447            // quadtree payload generation
448            ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>();
449
450            // we need these for the sizes of the quadrants
451            double[] big =
452                { 0, 0 };
453            double[] small =
454                { Double.MAX_VALUE, Double.MAX_VALUE };
455
456            // set quadtree payload values and get max and min x and y values for size
457            for (int i = 0; i < X.length; i++) {
458                if (X[i][0] >= big[0]) {
459                    big[0] = X[i][0];
460                }
461                if (X[i][1] >= big[1]) {
462                    big[1] = X[i][1];
463                }
464                if (X[i][0] <= small[0]) {
465                    small[0] = X[i][0];
466                }
467                if (X[i][1] <= small[1]) {
468                    small[1] = X[i][1];
469                }
470                QuadTreePayload<Instance> tmp =
471                    new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i));
472                qtp.add(tmp);
473            }
474
475            // Console.traceln(Level.INFO,
476            // String.format("size for cluster ("+small[0]+","+small[1]+") -
477            // ("+big[0]+","+big[1]+")"));
478
479            // 5. generate quadtree
480            QuadTree TREE = new QuadTree(null, qtp);
481            QuadTree.size = train.size();
482            QuadTree.alpha = Math.sqrt(train.size());
483            QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>();
484            QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>();
485
486            // Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size +
487            // " size, Alpha: "+ QuadTree.alpha+ ""));
488
489            // set the size and then split the tree recursively at the median value for x, y
490            TREE.setSize(new double[]
491                { small[0], big[0] }, new double[]
492                { small[1], big[1] });
493
494            // recursive split und grid clustering eher static
495            QuadTree.recursiveSplit(TREE);
496
497            // generate list of nodes sorted by density (childs only)
498            ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE));
499
500            // recursive grid clustering (tree pruning), the values are stored in ccluster
501            TREE.gridClustering(l);
502
503            // wir iterieren durch die cluster und sammeln uns die instanzen daraus
504            // ctraindata.clear();
505            for (int i = 0; i < QuadTree.ccluster.size(); i++) {
506                ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i);
507
508                // i is the clusternumber
509                // we only allow clusters with Instances > ALPHA, other clusters are not considered!
510                // if(current.size() > QuadTree.alpha) {
511                if (current.size() > 4) {
512                    for (int j = 0; j < current.size(); j++) {
513                        if (!ctraindata.containsKey(i)) {
514                            ctraindata.put(i, new Instances(train2));
515                            ctraindata.get(i).delete();
516                        }
517                        ctraindata.get(i).add(current.get(j).getInst());
518                    }
519                }
520                else {
521                    Console.traceln(Level.INFO, String
522                        .format("drop cluster, only: " + current.size() + " instances"));
523                }
524            }
525
526            // here we keep things we need later on
527            // QuadTree sizes for later use (matching new instances)
528            this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize);
529
530            // pivot elements
531            // this.cpivots.clear();
532            for (int i = 0; i < FMAP.PA[0].length; i++) {
533                this.cpivots.put(FMAP.PA[0][i], (Instance) train.get(FMAP.PA[0][i]).copy());
534            }
535            for (int j = 0; j < FMAP.PA[0].length; j++) {
536                this.cpivots.put(FMAP.PA[1][j], (Instance) train.get(FMAP.PA[1][j]).copy());
537            }
538
539            /*
540             * debug output int pnumber; Iterator<Integer> pivotnumber =
541             * cpivots.keySet().iterator(); while ( pivotnumber.hasNext() ) { pnumber =
542             * pivotnumber.next(); Console.traceln(Level.INFO, String.format("pivot: "+pnumber+
543             * " inst: "+cpivots.get(pnumber))); }
544             */
545
546            // train one classifier per cluster, we get the cluster number from the traindata
547            int cnumber;
548            Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
549            // cclassifier.clear();
550
551            // int traindata_count = 0;
552            while (clusternumber.hasNext()) {
553                cnumber = clusternumber.next();
554                cclassifier.put(cnumber, setupClassifier()); // this is the classifier used for the
555                                                             // cluster
556                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
557                // Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
558                // traindata_count += ctraindata.get(cnumber).size();
559                // Console.traceln(Level.INFO,
560                // String.format("building classifier in cluster "+cnumber +" with "+
561                // ctraindata.get(cnumber).size() +" traindata instances"));
562            }
563
564            // add all traindata
565            // Console.traceln(Level.INFO, String.format("traindata in all clusters: " +
566            // traindata_count));
567        }
568    }
569
570    /**
571     * <p>
572     * Payload for the QuadTree. x and y are the calculated Fastmap values. T is a Weka instance.
573     * </p>
574     *
575     * @author Alexander Trautsch
576     * @param <T> type of the instance
577     */
578    public class QuadTreePayload<T> {
579
580        /**
581         * x-value
582         */
583        public final double x;
584
585        /**
586         * y-value
587         */
588        public final double y;
589
590        /**
591         * associated instance
592         */
593        private T inst;
594
595        /**
596         * <p>
597         * Constructor. Creates the payload.
598         * </p>
599         *
600         * @param x
601         *            x-value
602         * @param y
603         *            y-value
604         * @param value
605         *            associated instace
606         */
607        public QuadTreePayload(double x, double y, T value) {
608            this.x = x;
609            this.y = y;
610            this.inst = value;
611        }
612
613        /**
614         * <p>
615         * returns the instance
616         * </p>
617         *
618         * @return the instance
619         */
620        public T getInst() {
621            return this.inst;
622        }
623    }
624
625    /**
626     * <p>
627     * Fastmap implementation after:<br>
628     * * Faloutsos, C., & Lin, K. I. (1995). FastMap: A fast algorithm for indexing, data-mining and
629     * visualization of traditional and multimedia datasets (Vol. 24, No. 2, pp. 163-174). ACM.
630     * </p>
631     */
632    public class Fastmap {
633
634        /**
635         * N x k Array, at the end, the i-th row will be the image of the i-th object
636         */
637        private double[][] X;
638
639        /**
640         * 2 x k pivot Array one pair per recursive call
641         */
642        private int[][] PA;
643
644        /**
645         * Objects we got (distance matrix)
646         */
647        private double[][] O;
648
649        /**
650         * column of X currently updated (also the dimension)
651         */
652        private int col = 0;
653
654        /**
655         * number of dimensions we want
656         */
657        private int target_dims = 0;
658
659        /**
660         * if we already have the pivot elements
661         */
662        private boolean pivot_set = false;
663
664        /**
665         * <p>
666         * Constructor. Creates a new Fastmap object.
667         * </p>
668         *
669         * @param k
670         */
671        public Fastmap(int k) {
672            this.target_dims = k;
673        }
674
675        /**
676         * <p>
677         * Sets the distance matrix and params that depend on this.
678         * </p>
679         *
680         * @param O
681         *            distance matrix
682         */
683        public void setDistmat(double[][] O) {
684            this.O = O;
685            int N = O.length;
686            this.X = new double[N][this.target_dims];
687            this.PA = new int[2][this.target_dims];
688        }
689
690        /**
691         * <p>
692         * Set pivot elements, we need that to classify instances after the calculation is complete
693         * (because we then want to reuse only the pivot elements).
694         * </p>
695         *
696         * @param pi
697         *            the pivots
698         */
699        public void setPivots(int[][] pi) {
700            this.pivot_set = true;
701            this.PA = pi;
702        }
703
704        /**
705         * <p>
706         * Return the pivot elements that were chosen during the calculation
707         * </p>
708         *
709         * @return the pivots
710         */
711        public int[][] getPivots() {
712            return this.PA;
713        }
714
715        /**
716         * <p>
717         * The distance function for euclidean distance. Acts according to equation 4 of the Fastmap
718         * paper.
719         * </p>
720         *
721         * @param x
722         *            x index of x image (if k==0 x object)
723         * @param y
724         *            y index of y image (if k==0 y object)
725         * @param k
726         *            dimensionality
727         * @return the distance
728         */
729        private double dist(int x, int y, int k) {
730
731            // basis is object distance, we get this from our distance matrix
732            double tmp = this.O[x][y] * this.O[x][y];
733
734            // decrease by projections
735            for (int i = 0; i < k; i++) {
736                double tmp2 = (this.X[x][i] - this.X[y][i]);
737                tmp -= tmp2 * tmp2;
738            }
739
740            return Math.abs(tmp);
741        }
742
743        /**
744         * <p>
745         * Find the object farthest from the given index. This method is a helper Method for
746         * findDistandObjects.
747         * </p>
748         *
749         * @param index
750         *            of the object
751         * @return index of the farthest object from the given index
752         */
753        private int findFarthest(int index) {
754            double furthest = Double.MIN_VALUE;
755            int ret = 0;
756
757            for (int i = 0; i < O.length; i++) {
758                double dist = this.dist(i, index, this.col);
759                if (i != index && dist > furthest) {
760                    furthest = dist;
761                    ret = i;
762                }
763            }
764            return ret;
765        }
766
767        /**
768         * <p>
769         * Finds the pivot objects. This method is basically algorithm 1 of the Fastmap paper.
770         * </p>
771         *
772         * @return 2 indexes of the chosen pivot objects
773         */
774        private int[] findDistantObjects() {
775            // 1. choose object randomly
776            Random r = new Random();
777            int obj = r.nextInt(this.O.length);
778
779            // 2. find farthest object from randomly chosen object
780            int idx1 = this.findFarthest(obj);
781
782            // 3. find farthest object from previously farthest object
783            int idx2 = this.findFarthest(idx1);
784
785            return new int[]
786                { idx1, idx2 };
787        }
788
789        /**
790         * <p>
791         * Calculates the new k-vector values (projections) This is basically algorithm 2 of the
792         * fastmap paper. We just added the possibility to pre-set the pivot elements because we
793         * need to classify single instances after the computation is already done.
794         * </p>
795         */
796        public void calculate() {
797
798            for (int k = 0; k < this.target_dims; k++) {
799                // 2) choose pivot objects
800                if (!this.pivot_set) {
801                    int[] pivots = this.findDistantObjects();
802
803                    // 3) record ids of pivot objects
804                    this.PA[0][this.col] = pivots[0];
805                    this.PA[1][this.col] = pivots[1];
806                }
807
808                // 4) inter object distances are zero (this.X is initialized with 0 so we just
809                // continue)
810                if (this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0) {
811                    continue;
812                }
813
814                // 5) project the objects on the line between the pivots
815                double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col);
816                for (int i = 0; i < this.O.length; i++) {
817
818                    double dix = this.dist(i, this.PA[0][this.col], this.col);
819                    double diy = this.dist(i, this.PA[1][this.col], this.col);
820
821                    double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy));
822
823                    // save the projection
824                    this.X[i][this.col] = tmp;
825                }
826
827                this.col += 1;
828            }
829        }
830
831        /**
832         * <p>
833         * returns the result matrix of the projections
834         * </p>
835         *
836         * @return calculated result
837         */
838        public double[][] getX() {
839            return this.X;
840        }
841    }
842}
Note: See TracBrowser for help on using the repository browser.