source: trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalTraining2.java @ 21

Last change on this file since 21 was 21, checked in by sherbold, 10 years ago
  • bug fixes for local trainers
File size: 23.7 KB
Line 
1package de.ugoe.cs.cpdp.training;
2
3import java.io.PrintStream;
4import java.util.ArrayList;
5import java.util.HashMap;
6import java.util.HashSet;
7import java.util.Iterator;
8import java.util.Random;
9import java.util.Set;
10import java.util.logging.Level;
11
12import org.apache.commons.io.output.NullOutputStream;
13
14import de.ugoe.cs.cpdp.training.QuadTree;
15import de.ugoe.cs.util.console.Console;
16import weka.classifiers.AbstractClassifier;
17import weka.classifiers.Classifier;
18import weka.core.DenseInstance;
19import weka.core.EuclideanDistance;
20import weka.core.Instance;
21import weka.core.Instances;
22import weka.filters.Filter;
23import weka.filters.unsupervised.attribute.Remove;
24
25/**
26 * Trainer with reimplementation of WHERE clustering algorithm from:
27 * Tim Menzies, Andrew Butcher, David Cok, Andrian Marcus, Lucas Layman,
28 * Forrest Shull, Burak Turhan, Thomas Zimmermann,
29 * "Local versus Global Lessons for Defect Prediction and Effort Estimation,"
30 * IEEE Transactions on Software Engineering, vol. 39, no. 6, pp. 822-834, June, 2013 
31 *
32 * With WekaLocalTraining2 we do the following:
33 * 1) Run the Fastmap algorithm on all training data, let it calculate the 2 most significant
34 *    dimensions and projections of each instance to these dimensions
35 * 2) With these 2 dimensions we span a QuadTree which gets recursively split on median(x) and median(y) values.
36 * 3) We cluster the QuadTree nodes together if they have similar density (50%)
37 * 4) We save the clusters and their training data
38 * 5) We only use clusters with > ALPHA instances (currently Math.sqrt(SIZE)), rest is discarded with the training data of this cluster
39 * 6) We train a Weka classifier for each cluster with the clusters training data
40 * 7) We recalculate Fastmap distances for a single instance with the old pivots and then try to find a cluster containing the coords of the instance.
41 * 7.1.) If we can not find a cluster (due to coords outside of all clusters) we find the nearest cluster.
42 * 8) We classify the Instance with the classifier and traindata from the Cluster we found in 7.
43 */
44public class WekaLocalTraining2 extends WekaBaseTraining2 implements ITrainingStrategy {
45       
46        private final TraindatasetCluster classifier = new TraindatasetCluster();
47       
48        @Override
49        public Classifier getClassifier() {
50                return classifier;
51        }
52       
53        @Override
54        public void apply(Instances traindata) {
55                PrintStream errStr      = System.err;
56                System.setErr(new PrintStream(new NullOutputStream()));
57                try {
58                        classifier.buildClassifier(traindata);
59                } catch (Exception e) {
60                        throw new RuntimeException(e);
61                } finally {
62                        System.setErr(errStr);
63                }
64        }
65       
66       
67        public class TraindatasetCluster extends AbstractClassifier {
68               
69                private static final long serialVersionUID = 1L;
70               
71                /* classifier per cluster */
72                private HashMap<Integer, Classifier> cclassifier;
73               
74                /* instances per cluster */
75                private HashMap<Integer, Instances> ctraindata;
76               
77                /* holds the instances and indices of the pivot objects of the Fastmap calculation in buildClassifier*/
78                private HashMap<Integer, Instance> cpivots;
79               
80                /* holds the indices of the pivot objects for x,y and the dimension [x,y][dimension]*/
81                private int[][] cpivotindices;
82
83                /* holds the sizes of the cluster multiple "boxes" per cluster */
84                private HashMap<Integer, ArrayList<Double[][]>> csize;
85               
86                /* debug vars */
87                @SuppressWarnings("unused")
88                private boolean show_biggest = true;
89               
90                @SuppressWarnings("unused")
91                private int CFOUND = 0;
92                @SuppressWarnings("unused")
93                private int CNOTFOUND = 0;
94               
95               
96                private Instance createInstance(Instances instances, Instance instance) {
97                        // attributes for feeding instance to classifier
98                        Set<String> attributeNames = new HashSet<>();
99                        for( int j=0; j<instances.numAttributes(); j++ ) {
100                                attributeNames.add(instances.attribute(j).name());
101                        }
102                       
103                        double[] values = new double[instances.numAttributes()];
104                        int index = 0;
105                        for( int j=0; j<instance.numAttributes(); j++ ) {
106                                if( attributeNames.contains(instance.attribute(j).name())) {
107                                        values[index] = instance.value(j);
108                                        index++;
109                                }
110                        }
111                       
112                        Instances tmp = new Instances(instances);
113                        tmp.clear();
114                        Instance instCopy = new DenseInstance(instance.weight(), values);
115                        instCopy.setDataset(tmp);
116                       
117                        return instCopy;
118                }
119               
120                /**
121                 * Because Fastmap saves only the image not the values of the attributes it used
122                 * we can not use the old data directly to classify single instances to clusters.
123                 *
124                 * To classify a single instance we do a new fastmap computation with only the instance and
125                 * the old pivot elements.
126                 *
127                 * After that we find the cluster with our fastmap result for x and y.
128                 */
129                @Override
130                public double classifyInstance(Instance instance) {
131                       
132                        double ret = 0;
133                        try {
134                                // classinstance gets passed to classifier
135                                Instances traindata = ctraindata.get(0);
136                                Instance classInstance = createInstance(traindata, instance);
137
138                                // this one keeps the class attribute
139                                Instances traindata2 = ctraindata.get(1); 
140                               
141                                // remove class attribute before clustering
142                                Remove filter = new Remove();
143                                filter.setAttributeIndices("" + (traindata.classIndex() + 1));
144                                filter.setInputFormat(traindata);
145                                traindata = Filter.useFilter(traindata, filter);
146                                Instance clusterInstance = createInstance(traindata, instance);
147                               
148                                Fastmap FMAP = new Fastmap(2);
149                                EuclideanDistance dist = new EuclideanDistance(traindata);
150                               
151                                // we set our pivot indices [x=0,y=1][dimension]
152                                int[][] npivotindices = new int[2][2];
153                                npivotindices[0][0] = 1;
154                                npivotindices[1][0] = 2;
155                                npivotindices[0][1] = 3;
156                                npivotindices[1][1] = 4;
157                               
158                                // build temp dist matrix (2 pivots per dimension + 1 instance we want to classify)
159                                // the instance we want to classify comes first after that the pivot elements in the order defined above
160                                double[][] distmat = new double[2*FMAP.target_dims+1][2*FMAP.target_dims+1];
161                                distmat[0][0] = 0;
162                                distmat[0][1] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][0]));
163                                distmat[0][2] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][0]));
164                                distmat[0][3] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[0][1]));
165                                distmat[0][4] = dist.distance(clusterInstance, this.cpivots.get((Integer)this.cpivotindices[1][1]));
166                               
167                                distmat[1][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), clusterInstance);
168                                distmat[1][1] = 0;
169                                distmat[1][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][0]));
170                                distmat[1][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[0][1]));
171                                distmat[1][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][0]), this.cpivots.get((Integer)this.cpivotindices[1][1]));
172                               
173                                distmat[2][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), clusterInstance);
174                                distmat[2][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][0]));
175                                distmat[2][2] = 0;
176                                distmat[2][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[0][1]));
177                                distmat[2][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][0]), this.cpivots.get((Integer)this.cpivotindices[1][1]));
178                               
179                                distmat[3][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), clusterInstance);
180                                distmat[3][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[0][0]));
181                                distmat[3][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][0]));
182                                distmat[3][3] = 0;
183                                distmat[3][4] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[0][1]), this.cpivots.get((Integer)this.cpivotindices[1][1]));
184
185                                distmat[4][0] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), clusterInstance);
186                                distmat[4][1] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][0]));
187                                distmat[4][2] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[1][0]));
188                                distmat[4][3] = dist.distance(this.cpivots.get((Integer)this.cpivotindices[1][1]), this.cpivots.get((Integer)this.cpivotindices[0][1]));
189                                distmat[4][4] = 0;
190                               
191                               
192                                /* debug output: show biggest distance found within the new distance matrix
193                                double biggest = 0;
194                                for(int i=0; i < distmat.length; i++) {
195                                        for(int j=0; j < distmat[0].length; j++) {
196                                                if(biggest < distmat[i][j]) {
197                                                        biggest = distmat[i][j];
198                                                }
199                                        }
200                                }
201                                if(this.show_biggest) {
202                                        Console.traceln(Level.INFO, String.format(""+clusterInstance));
203                                        Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest));
204                                        this.show_biggest = false;
205                                }
206                                */
207
208                                FMAP.setDistmat(distmat);
209                                FMAP.setPivots(npivotindices);
210                                FMAP.calculate();
211                                double[][] x = FMAP.getX();
212                                double[] proj = x[0];
213
214                                // debug output: show the calculated distance matrix, our result vektor for the instance and the complete result matrix
215                                /*
216                                Console.traceln(Level.INFO, "distmat:");
217                            for(int i=0; i<distmat.length; i++){
218                                for(int j=0; j<distmat[0].length; j++){
219                                        Console.trace(Level.INFO, String.format("%20s", distmat[i][j]));
220                                }
221                                Console.traceln(Level.INFO, "");
222                            }
223                           
224                            Console.traceln(Level.INFO, "vector:");
225                            for(int i=0; i < proj.length; i++) {
226                                Console.trace(Level.INFO, String.format("%20s", proj[i]));
227                            }
228                            Console.traceln(Level.INFO, "");
229                           
230                                Console.traceln(Level.INFO, "resultmat:");
231                            for(int i=0; i<x.length; i++){
232                                for(int j=0; j<x[0].length; j++){
233                                        Console.trace(Level.INFO, String.format("%20s", x[i][j]));
234                                }
235                                Console.traceln(Level.INFO, "");
236                            }
237                            */
238                               
239                                // now we iterate over all clusters (well, boxes of sizes per cluster really) and save the number of the
240                                // cluster in which we are
241                                int cnumber;
242                                int found_cnumber = -1;
243                                Iterator<Integer> clusternumber = this.csize.keySet().iterator();
244                                while ( clusternumber.hasNext() && found_cnumber == -1) {
245                                        cnumber = clusternumber.next();
246                                       
247                                        // now iterate over the boxes of the cluster and hope we find one (cluster could have been removed)
248                                        // or we are too far away from any cluster because of the fastmap calculation with the initial pivot objects
249                                        for ( int box=0; box < this.csize.get(cnumber).size(); box++ ) {
250                                                Double[][] current = this.csize.get(cnumber).get(box);
251                                               
252                                                if(proj[0] >= current[0][0] && proj[0] <= current[0][1] &&  // x
253                                                   proj[1] >= current[1][0] && proj[1] <= current[1][1]) {  // y
254                                                        found_cnumber = cnumber;
255                                                }
256                                        }
257                                }
258                               
259                                // we want to count how often we are really inside a cluster
260                                //if ( found_cnumber == -1 ) {
261                                //      CNOTFOUND += 1;
262                                //}else {
263                                //      CFOUND += 1;
264                                //}
265
266                                // now it can happen that we do not find a cluster because we deleted it previously (too few instances)
267                                // or we get bigger distance measures from weka so that we are completely outside of our clusters.
268                                // in these cases we just find the nearest cluster to our instance and use it for classification.
269                                // to do that we use the EuclideanDistance again to compare our distance to all other Instances
270                                // then we take the cluster of the closest weka instance
271                                dist = new EuclideanDistance(traindata2);
272                                if( !this.ctraindata.containsKey(found_cnumber) ) {
273                                        double min_distance = Double.MAX_VALUE;
274                                        clusternumber = ctraindata.keySet().iterator();
275                                        while ( clusternumber.hasNext() ) {
276                                                cnumber = clusternumber.next();
277                                                for(int i=0; i < ctraindata.get(cnumber).size(); i++) {
278                                                        if(dist.distance(instance, ctraindata.get(cnumber).get(i)) <= min_distance) {
279                                                                found_cnumber = cnumber;
280                                                                min_distance = dist.distance(instance, ctraindata.get(cnumber).get(i));
281                                                        }
282                                                }
283                                        }
284                                }
285                               
286                                // here we have the cluster where an instance has the minimum distance between itself and the
287                                // instance we want to classify
288                                // if we still have not found a cluster we exit because something is really wrong
289                                if( found_cnumber == -1 ) {
290                                        Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster with full search!"));
291                                        throw new RuntimeException("cluster not found with full search");
292                                }
293                               
294                                // classify the passed instance with the cluster we found and its training data
295                                ret = cclassifier.get(found_cnumber).classifyInstance(classInstance);
296                               
297                        }catch( Exception e ) {
298                                Console.traceln(Level.INFO, String.format("ERROR matching instance to cluster!"));
299                                throw new RuntimeException(e);
300                        }
301                        return ret;
302                }
303               
304                @Override
305                public void buildClassifier(Instances traindata) throws Exception {
306                       
307                        //Console.traceln(Level.INFO, String.format("found: "+ CFOUND + ", notfound: " + CNOTFOUND));
308                        this.show_biggest = true;
309                       
310                        cclassifier = new HashMap<Integer, Classifier>();
311                        ctraindata = new HashMap<Integer, Instances>();
312                        cpivots = new HashMap<Integer, Instance>();
313                        cpivotindices = new int[2][2];
314                       
315                        // 1. copy traindata
316                        Instances train = new Instances(traindata);
317                        Instances train2 = new Instances(traindata);  // this one keeps the class attribute
318                       
319                        // 2. remove class attribute for clustering
320                        Remove filter = new Remove();
321                        filter.setAttributeIndices("" + (train.classIndex() + 1));
322                        filter.setInputFormat(train);
323                        train = Filter.useFilter(train, filter);
324                       
325                        // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1)
326                        double biggest = 0;
327                        EuclideanDistance dist = new EuclideanDistance(train);
328                        double[][] distmat = new double[train.size()][train.size()];
329                        for( int i=0; i < train.size(); i++ ) {
330                                for( int j=0; j < train.size(); j++ ) {
331                                        distmat[i][j] = dist.distance(train.get(i), train.get(j));
332                                        if( distmat[i][j] > biggest ) {
333                                                biggest = distmat[i][j];
334                                        }
335                                }
336                        }
337                        //Console.traceln(Level.INFO, String.format("biggest distances: "+ biggest));
338                       
339                        // 4. run fastmap for 2 dimensions on the distance matrix
340                        Fastmap FMAP = new Fastmap(2);
341                        FMAP.setDistmat(distmat);
342                        FMAP.calculate();
343                       
344                        cpivotindices = FMAP.getPivots();
345                       
346                        double[][] X = FMAP.getX();
347                        distmat = new double[0][0];
348                        System.gc();
349                       
350                        // quadtree payload generation
351                        ArrayList<QuadTreePayload<Instance>> qtp = new ArrayList<QuadTreePayload<Instance>>();
352                   
353                        // we need these for the sizes of the quadrants
354                        double[] big = {0,0};
355                        double[] small = {Double.MAX_VALUE,Double.MAX_VALUE};
356                       
357                        // set quadtree payload values and get max and min x and y values for size
358                    for( int i=0; i<X.length; i++ ){
359                        if(X[i][0] >= big[0]) {
360                                big[0] = X[i][0];
361                        }
362                        if(X[i][1] >= big[1]) {
363                                big[1] = X[i][1];
364                        }
365                        if(X[i][0] <= small[0]) {
366                                small[0] = X[i][0];
367                        }
368                        if(X[i][1] <= small[1]) {
369                                small[1] = X[i][1];
370                        }
371                        QuadTreePayload<Instance> tmp = new QuadTreePayload<Instance>(X[i][0], X[i][1], train2.get(i));
372                        qtp.add(tmp);
373                    }
374                   
375                    //Console.traceln(Level.INFO, String.format("size for cluster ("+small[0]+","+small[1]+") - ("+big[0]+","+big[1]+")"));
376                   
377                    // 5. generate quadtree
378                    QuadTree TREE = new QuadTree(null, qtp);
379                    QuadTree.size = train.size();
380                    QuadTree.alpha = Math.sqrt(train.size());
381                    QuadTree.ccluster = new ArrayList<ArrayList<QuadTreePayload<Instance>>>();
382                    QuadTree.csize = new HashMap<Integer, ArrayList<Double[][]>>();
383                   
384                    //Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ QuadTree.size + " size, Alpha: "+ QuadTree.alpha+ ""));
385                   
386                    // set the size and then split the tree recursively at the median value for x, y
387                    TREE.setSize(new double[] {small[0], big[0]}, new double[] {small[1], big[1]});
388                   
389                    // recursive split und grid clustering eher static
390                    TREE.recursiveSplit(TREE);
391                   
392                    // generate list of nodes sorted by density (childs only)
393                    ArrayList<QuadTree> l = new ArrayList<QuadTree>(TREE.getList(TREE));
394                   
395                    // recursive grid clustering (tree pruning), the values are stored in ccluster
396                    TREE.gridClustering(l);
397                   
398                    // wir iterieren durch die cluster und sammeln uns die instanzen daraus
399                    //ctraindata.clear();
400                    for( int i=0; i < QuadTree.ccluster.size(); i++ ) {
401                        ArrayList<QuadTreePayload<Instance>> current = QuadTree.ccluster.get(i);
402                       
403                        // i is the clusternumber
404                        // we only allow clusters with Instances > ALPHA, other clusters are not considered!
405                        //if(current.size() > QuadTree.alpha) {
406                        if( current.size() > 4 ) {
407                                for( int j=0; j < current.size(); j++ ) {
408                                        if( !ctraindata.containsKey(i) ) {
409                                                ctraindata.put(i, new Instances(train2));
410                                                ctraindata.get(i).delete();
411                                        }
412                                        ctraindata.get(i).add(current.get(j).getInst());
413                                }
414                        }else{
415                                Console.traceln(Level.INFO, String.format("drop cluster, only: " + current.size() + " instances"));
416                        }
417                    }
418                       
419                        // here we keep things we need later on
420                        // QuadTree sizes for later use (matching new instances)
421                        this.csize = new HashMap<Integer, ArrayList<Double[][]>>(QuadTree.csize);
422               
423                        // pivot elements
424                        //this.cpivots.clear();
425                        for( int i=0; i < FMAP.PA[0].length; i++ ) {
426                                this.cpivots.put(FMAP.PA[0][i], (Instance)train.get(FMAP.PA[0][i]).copy());
427                        }
428                        for( int j=0; j < FMAP.PA[0].length; j++ ) {
429                                this.cpivots.put(FMAP.PA[1][j], (Instance)train.get(FMAP.PA[1][j]).copy());
430                        }
431                       
432                       
433                        /* debug output
434                        int pnumber;
435                        Iterator<Integer> pivotnumber = cpivots.keySet().iterator();
436                        while ( pivotnumber.hasNext() ) {
437                                pnumber = pivotnumber.next();
438                                Console.traceln(Level.INFO, String.format("pivot: "+pnumber+ " inst: "+cpivots.get(pnumber)));
439                        }
440                        */
441                       
442                    // train one classifier per cluster, we get the cluster number from the traindata
443                    int cnumber;
444                        Iterator<Integer> clusternumber = ctraindata.keySet().iterator();
445                        //cclassifier.clear();
446                       
447                        //int traindata_count = 0;
448                        while ( clusternumber.hasNext() ) {
449                                cnumber = clusternumber.next();
450                                cclassifier.put(cnumber,setupClassifier());  // this is the classifier used for the cluster
451                                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber));
452                                //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber));
453                                //traindata_count += ctraindata.get(cnumber).size();
454                                //Console.traceln(Level.INFO, String.format("building classifier in cluster "+cnumber +"  with "+ ctraindata.get(cnumber).size() +" traindata instances"));
455                        }
456                       
457                        // add all traindata
458                        //Console.traceln(Level.INFO, String.format("traindata in all clusters: " + traindata_count));
459                }
460        }
461       
462
463        /**
464         * Payload for the QuadTree.
465         * x and y are the calculated Fastmap values.
466         * T is a weka instance.
467         */
468        public class QuadTreePayload<T> {
469
470                public double x;
471                public double y;
472                private T inst;
473               
474                public QuadTreePayload(double x, double y, T value) {
475                        this.x = x;
476                        this.y = y;
477                        this.inst = value;
478                }
479               
480                public T getInst() {
481                        return this.inst;
482                }
483        }
484       
485       
486        /**
487         * Fastmap implementation
488         *
489         * Faloutsos, C., & Lin, K. I. (1995).
490         * FastMap: A fast algorithm for indexing, data-mining and visualization of traditional and multimedia datasets
491         * (Vol. 24, No. 2, pp. 163-174). ACM.
492         */
493        public class Fastmap {
494               
495                /*N x k Array, at the end, the i-th row will be the image of the i-th object*/
496                private double[][] X;
497               
498                /*2 x k pivot Array one pair per recursive call*/
499                private int[][] PA;
500               
501                /*Objects we got (distance matrix)*/
502                private double[][] O;
503               
504                /*column of X currently updated (also the dimension)*/
505                private int col = 0;
506               
507                /*number of dimensions we want*/
508                private int target_dims = 0;
509               
510                // if we already have the pivot elements
511                private boolean pivot_set = false;
512               
513
514                public Fastmap(int k) {
515                        this.target_dims = k;
516                }
517               
518                /**
519                 * Sets the distance matrix
520                 * and params that depend on this
521                 * @param O
522                 */
523                public void setDistmat(double[][] O) {
524                        this.O = O;
525                        int N = O.length;
526                        this.X = new double[N][this.target_dims];
527                        this.PA = new int[2][this.target_dims];
528                }
529               
530                /**
531                 * Set pivot elements, we need that to classify instances
532                 * after the calculation is complete (because we then want to reuse
533                 * only the pivot elements).
534                 *
535                 * @param pi
536                 */
537                public void setPivots(int[][] pi) {
538                        this.pivot_set = true;
539                        this.PA = pi;
540                }
541               
542                /**
543                 * Return the pivot elements that were chosen during the calculation
544                 *
545                 * @return
546                 */
547                public int[][] getPivots() {
548                        return this.PA;
549                }
550               
551                /**
552                 * The distance function for euclidean distance
553                 *
554                 * Acts according to equation 4 of the fastmap paper
555                 * 
556                 * @param x x index of x image (if k==0 x object)
557                 * @param y y index of y image (if k==0 y object)
558                 * @param kdimensionality
559                 * @return distance
560                 */
561                private double dist(int x, int y, int k) {
562                       
563                        // basis is object distance, we get this from our distance matrix
564                        double tmp = this.O[x][y] * this.O[x][y];
565                       
566                        // decrease by projections
567                        for( int i=0; i < k; i++ ) {
568                                double tmp2 = (this.X[x][i] - this.X[y][i]);
569                                tmp -= tmp2 * tmp2;
570                        }
571                       
572                        return Math.abs(tmp);
573                }
574
575                /**
576                 * Find the object farthest from the given index
577                 * This method is a helper Method for findDistandObjects
578                 *
579                 * @param index of the object
580                 * @return index of the farthest object from the given index
581                 */
582                private int findFarthest(int index) {
583                        double furthest = Double.MIN_VALUE;
584                        int ret = 0;
585                       
586                        for( int i=0; i < O.length; i++ ) {
587                                double dist = this.dist(i, index, this.col);
588                                if( i != index && dist > furthest ) {
589                                        furthest = dist;
590                                        ret = i;
591                                }
592                        }
593                        return ret;
594                }
595               
596                /**
597                 * Finds the pivot objects
598                 *
599                 * This method is basically algorithm 1 of the fastmap paper.
600                 *
601                 * @return 2 indexes of the choosen pivot objects
602                 */
603                private int[] findDistantObjects() {
604                        // 1. choose object randomly
605                        Random r = new Random();
606                        int obj = r.nextInt(this.O.length);
607                       
608                        // 2. find farthest object from randomly chosen object
609                        int idx1 = this.findFarthest(obj);
610                       
611                        // 3. find farthest object from previously farthest object
612                        int idx2 = this.findFarthest(idx1);
613
614                        return new int[] {idx1, idx2};
615                }
616       
617                /**
618                 * Calculates the new k-vector values (projections)
619                 *
620                 * This is basically algorithm 2 of the fastmap paper.
621                 * We just added the possibility to pre-set the pivot elements because
622                 * we need to classify single instances after the computation is already done.
623                 *
624                 * @param dims dimensionality
625                 */
626                public void calculate() {
627                       
628                        for( int k=0; k < this.target_dims; k++ ) {
629                                // 2) choose pivot objects
630                                if ( !this.pivot_set ) {
631                                        int[] pivots = this.findDistantObjects();
632               
633                                        // 3) record ids of pivot objects
634                                        this.PA[0][this.col] = pivots[0];
635                                        this.PA[1][this.col] = pivots[1];
636                                }
637                               
638                                // 4) inter object distances are zero (this.X is initialized with 0 so we just continue)
639                                if( this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col) == 0 ) {
640                                        continue;
641                                }
642                               
643                                // 5) project the objects on the line between the pivots
644                                double dxy = this.dist(this.PA[0][this.col], this.PA[1][this.col], this.col);
645                                for( int i=0; i < this.O.length; i++ ) {
646                                       
647                                        double dix = this.dist(i, this.PA[0][this.col], this.col);
648                                        double diy = this.dist(i, this.PA[1][this.col], this.col);
649                                       
650                                        double tmp = (dix + dxy - diy) / (2 * Math.sqrt(dxy));
651                                       
652                                        // save the projection
653                                        this.X[i][this.col] = tmp;
654                                }
655                               
656                                this.col += 1;
657                        }
658                }
659               
660                /**
661                 * returns the result matrix of the projections
662                 *
663                 * @return calculated result
664                 */
665                public double[][] getX() {
666                        return this.X;
667                }
668        }
669}
Note: See TracBrowser for help on using the repository browser.