Changeset 12


Ignore:
Timestamp:
08/25/14 13:27:30 (10 years ago)
Author:
atrautsch
Message:

Update der Vorabversion des neuen Trainers.
Experiment kann nun durchgeführt werden.

File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaLocalTraining2.java

    r9 r12  
    1717import weka.classifiers.AbstractClassifier; 
    1818import weka.classifiers.Classifier; 
    19 import weka.clusterers.EM; 
    2019import weka.core.DenseInstance; 
    2120import weka.core.EuclideanDistance; 
     
    2827 * ACHTUNG UNFERTIG 
    2928 * 
     29 *  
     30 * Basically a copy of WekaClusterTraining2 with internal classes for the Fastmap and QuadTree implementations 
    3031 */ 
    3132public class WekaLocalTraining2 extends WekaBaseTraining2 implements ITrainingStrategy { 
    3233         
    3334        private final TraindatasetCluster classifier = new TraindatasetCluster(); 
    34         private final QuadTree q = null; 
    35         private final Fastmap f = null; 
    36          
    37         /*Stopping rule for tree reqursion (Math.sqrt(Instances)*/ 
    38         public static double ALPHA = 3; 
     35         
     36        // we do not need to keep them around 
     37        //private final QuadTree q = null; 
     38        //private final Fastmap f = null; 
     39         
     40        // these values are set later when we have all the information we need 
     41        /*Stopping rule for tree recursion (Math.sqrt(Instances)*/ 
     42        public static double ALPHA = 0; 
    3943        /*Stopping rule for clustering*/ 
    4044        public static double DELTA = 0.5; 
    41         /*size of the complete set (used for density)*/ 
     45        /*size of the complete set (used for density function)*/ 
    4246        public static int SIZE = 0; 
     47         
     48        public static int MIN_INST = 10; 
    4349         
    4450        // cluster 
     
    6975                private static final long serialVersionUID = 1L; 
    7076 
    71                 private EM clusterer = null; 
    72  
    7377                private HashMap<Integer, Classifier> cclassifier = new HashMap<Integer, Classifier>(); 
    7478                private HashMap<Integer, Instances> ctraindata = new HashMap<Integer, Instances>();  
    75                  
    7679                 
    7780                 
     
    100103                } 
    101104                 
    102                  
     105                /** 
     106                 * Because Fastmap saves only the image not the values of the attributes 
     107                 * we can not use it to classify single instances to values 
     108                 *  
     109                 * TODO: mehr erklärung 
     110                 * TODO: class lavel filter raus 
     111                 *  
     112                 * Finde die am nächsten liegende Instanz zur übergebenen 
     113                 * dann bestimme den cluster der instanz und führe dann den  
     114                 * classifier des clusters aus 
     115                 */ 
    103116                @Override 
    104117                public double classifyInstance(Instance instance) { 
     118                         
    105119                        double ret = 0; 
    106120                        try { 
     
    116130                                Instance clusterInstance = createInstance(traindata, instance); 
    117131                                 
    118                                 // 1. classify testdata instance to a cluster number 
    119                                 int cnum = clusterer.clusterInstance(clusterInstance); 
    120                                  
    121                                 // 2. classify testata instance to the classifier 
    122                                 ret = cclassifier.get(cnum).classifyInstance(classInstance); 
     132                                // get distance of this instance to every other instance 
     133                                // if the distance is minimal apply the classifier of the current cluster 
     134                                int cnumber; 
     135                                int min_cluster = -1; 
     136                                double min_distance = 99999999; 
     137                                EuclideanDistance d; 
     138                                Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
     139                                while ( clusternumber.hasNext() ) { 
     140                                        cnumber = clusternumber.next(); 
     141                                         
     142                                        d = new EuclideanDistance(ctraindata.get(cnumber)); 
     143                                        for(int i=0; i < ctraindata.get(cnumber).size(); i++) { 
     144                                                if(d.distance(clusterInstance, ctraindata.get(cnumber).get(i)) <= min_distance) { 
     145                                                        min_cluster = cnumber; 
     146                                                        min_distance = d.distance(clusterInstance, ctraindata.get(cnumber).get(i)); 
     147                                                } 
     148                                        } 
     149                                } 
     150                                 
     151                                // here we have the cluster where an instance has the minimum distance between itself the 
     152                                // instance we want to classify 
     153                                if(min_cluster == -1) { 
     154                                        // this is an error condition 
     155                                        throw new RuntimeException("min_cluster not found"); 
     156                                } 
     157                                 
     158                                // classify the passed instance with the cluster we found 
     159                                ret = cclassifier.get(min_cluster).classifyInstance(classInstance); 
    123160                                 
    124161                        }catch( Exception e ) { 
     
    128165                        return ret; 
    129166                } 
    130  
    131                  
    132167                 
    133168                @Override 
     
    138173                         
    139174                        // 2. remove class attribute for clustering 
    140                         Remove filter = new Remove(); 
    141                         filter.setAttributeIndices("" + (train.classIndex() + 1)); 
    142                         filter.setInputFormat(train); 
    143                         train = Filter.useFilter(train, filter); 
    144                          
    145                          
    146                         // 3. calculate distance matrix 
     175                        //Remove filter = new Remove(); 
     176                        //filter.setAttributeIndices("" + (train.classIndex() + 1)); 
     177                        //filter.setInputFormat(train); 
     178                        //train = Filter.useFilter(train, filter); 
     179                         
     180                         
     181                        // 3. calculate distance matrix (needed for Fastmap because it starts at dimension 1) 
    147182                        EuclideanDistance d = new EuclideanDistance(train); 
    148183                        double[][] dist = new double[train.size()][train.size()]; 
     
    153188                        } 
    154189                         
    155                         // 4. run fastmap for 2 dimensions on distance matrix 
     190                        // 4. run fastmap for 2 dimensions on the distance matrix 
    156191                        Fastmap f = new Fastmap(2, dist); 
    157192                        f.calculate(2); 
     
    163198                        // die max und min brauchen wir für die größenangaben der sektoren 
    164199                        double[] big = {0,0}; 
    165                         double[] small = {-99999,-99999}; 
    166                          
    167                         // set quadtree payload values 
     200                        double[] small = {9999999,99999999}; 
     201                         
     202                        // set quadtree payload values and get max and min x and y values for size 
    168203                    for(int i=0; i<X.length; i++){ 
    169204                        if(X[i][0] >= big[0]) { 
     
    188223                    SIZE = train.size(); 
    189224                     
    190                     // split recursively 
     225                    Console.traceln(Level.INFO, String.format("Generate QuadTree with "+ SIZE + " size, Alpha: "+ ALPHA+ "")); 
     226                     
     227                    // set the size and then split the tree recursively at the median value for x, y 
    191228                    q.setSize(new double[] {small[0], big[0]}, new double[] {small[1], big[1]}); 
    192229                    q.recursiveSplit(q); 
    193230                     
    194                     // generate list of nodes sorted by density 
     231                    // generate list of nodes sorted by density (childs only) 
    195232                    ArrayList<QuadTree> l = new ArrayList<QuadTree>(q.getList(q)); 
    196233                     
    197                     // grid clustering recursive (tree pruning) 
     234                    // recursive grid clustering (tree pruning), the values are stored in cluster! 
    198235                    q.gridClustering(l); 
    199236                     
     237                    // after grid clustering we need to remove the clusters with < 2 * ALPHA instances 
    200238                     
    201239                    // hier müssten wir sowas haben wie welche instanz in welchem cluster ist 
    202                     // oder wir iterieren durch die cluster und sammeln uns die instaznen daraus 
     240                    // oder wir iterieren durch die cluster und sammeln uns die instanzen daraus 
    203241                    for(int i=0; i < cluster.size(); i++) { 
    204242                        ArrayList<QuadTreePayload<Instance>> current = cluster.get(i); 
    205                         for(int j=0; j < current.size(); j++ ) { 
    206                                  
     243                         
     244                        // i is the clusternumber 
     245                        // we only allow clusters with Instances > ALPHA 
     246                        if(current.size() > ALPHA) { 
     247                                for(int j=0; j < current.size(); j++ ) { 
     248                                        if(!ctraindata.containsKey(i)) { 
     249                                                ctraindata.put(i, new Instances(traindata)); 
     250                                                ctraindata.get(i).delete(); 
     251                                        } 
     252                                        ctraindata.get(i).add(current.get(j).getInst()); 
     253                                } 
    207254                        } 
    208255                    } 
    209                      
    210                         Instances ctrain = new Instances(train); 
    211                          
    212                         // get traindata per cluster 
    213                         int cnumber; 
    214                         for ( int j=0; j < ctrain.numInstances(); j++ ) { 
    215                                 // get the cluster number from the attributes, subract 1 because if we clusterInstance we get 0-n, and this is 1-n 
    216                                 //cnumber = Integer.parseInt(ctrain.get(j).stringValue(ctrain.get(j).numAttributes()-1).replace("cluster", "")) - 1; 
    217                                  
    218                                 cnumber = clusterer.clusterInstance(ctrain.get(j)); 
    219                                 // add training data to list of instances for this cluster number 
    220                                 if ( !ctraindata.containsKey(cnumber) ) { 
    221                                         ctraindata.put(cnumber, new Instances(traindata)); 
    222                                         ctraindata.get(cnumber).delete(); 
    223                                 } 
    224                                 ctraindata.get(cnumber).add(traindata.get(j)); 
    225                         } 
    226                          
    227                         // train one classifier per cluster, we get the clusternumber from the traindata 
     256 
     257                    // train one classifier per cluster, we get the clusternumber from the traindata 
     258                    int cnumber; 
    228259                        Iterator<Integer> clusternumber = ctraindata.keySet().iterator(); 
    229260                        while ( clusternumber.hasNext() ) { 
     
    231262                                cclassifier.put(cnumber,setupClassifier()); 
    232263                                cclassifier.get(cnumber).buildClassifier(ctraindata.get(cnumber)); 
    233                                  
    234264                                //Console.traceln(Level.INFO, String.format("classifier in cluster "+cnumber)); 
     265                                //Console.traceln(Level.INFO, String.format("" + ctraindata.get(cnumber).size() + " instances in cluster "+cnumber)); 
    235266                        } 
    236267                } 
     
    254285                } 
    255286                 
    256                 public T getinst() { 
     287                public T getInst() { 
    257288                        return this.inst; 
    258289                } 
     
    261292        /** 
    262293         * Fastmap implementation 
     294         *  
     295         * TODO: only one place to pass dimension! 
    263296         *  
    264297         * Faloutsos, C., & Lin, K. I. (1995).  
     
    289322                } 
    290323                 
    291                  
    292                 /*recursive function ALT*/ 
    293                 private double dist2(int x, int y, int k) { 
    294                         // basisfall 
    295                         if(k == 0) { 
    296                                 return Math.pow(this.O[x][y], 2); 
    297                         } 
    298                          
    299                         double dist_rec = Math.pow(this.dist(x, y, k-1), 2); 
    300                         double dist_norm = Math.pow(Math.abs(this.X[x][k] - this.X[y][k]), 2);  
    301                          
    302                         return Math.sqrt(Math.abs(dist_rec - dist_norm)); 
    303                         //return Math.abs(dist_rec - dist_norm); 
    304                 } 
    305                  
    306324                /** 
    307325                 * The distance function for eculidean distance 
    308326                 *  
    309                  * Acts according toequation 4 of the fastmap paper 
     327                 * Acts according to equation 4 of the fastmap paper 
    310328                 *   
    311329                 * @param x x index of x image (if k==0 x object) 
     
    316334                private double dist(int x, int y, int k) { 
    317335                         
    318                         // objectabstand ist basis 
    319          
    320                         // das hier wäre ein abstand zwischen 2 weka instanzen, z.B. euclidischer abstand zwischen den beiden vektoren 
     336                        // basis is object distance, we get this from our distance matrix 
     337                        // alternatively we could provide a distance function that takes 2 vectors 
    321338                        double tmp = this.O[x][y] * this.O[x][y];  
    322                          
    323339                         
    324340                        // decrease by projections 
     
    333349                 
    334350                /** 
    335                  * Find the object furthest from the given index 
     351                 * Find the object farthest from the given index 
    336352                 * This method is a helper Method for findDistandObjects 
    337353                 *  
    338354                 * @param index of the object  
    339                  * @return index of the furthest object from the given index 
    340                  */ 
    341                 private int findFurthest(int index) { 
     355                 * @return index of the farthest object from the given index 
     356                 */ 
     357                private int findFarthest(int index) { 
    342358                        double furthest = -1000000; 
    343359                        int ret = 0; 
     
    353369                } 
    354370                 
    355          
    356371                /** 
    357372                 * Finds the pivot objects  
     
    366381                        int obj = r.nextInt(this.O.length); 
    367382                         
    368                         // 2. find furthest object from randomly chooen object 
    369                         int idx1 = this.findFurthest(obj); 
    370                          
    371                         // 3. find furthest object from previously furthest object 
    372                         int idx2 = this.findFurthest(idx1); 
    373                          
    374                         int[] ret = {idx1, idx2}; 
    375                         return ret; 
    376                 } 
    377          
    378                  
    379                 /** 
    380                  * Gives image of object (projection on the line between px, py) 
    381                  *  
    382                  * @param index of the object to project 
    383                  * @param px pivot 1 
    384                  * @param py pivot 2 
    385                  * @return projection 
    386                  */ 
    387                 private double project(int index, int px, int py) { 
    388                          
    389                         double dix = this.dist(index, px, this.col); 
    390                         double diy = this.dist(index, py, this.col); 
    391                         double dxy = this.dist(px, py, this.col); 
    392                          
    393                         return (dix + dxy - diy) / 2 * Math.sqrt(dxy); 
    394                 } 
    395          
    396          
    397                 /*recursive function ALT, geht auch sequentiell*/ 
    398                 public void calculate2(int k) { 
    399                          
    400                         // 1) basisfall 
    401                         if(k <= 0) { 
    402                                 return; 
    403                         } 
    404                          
    405                         // 2) choose pivot objects 
    406                         int[] pivots = this.findDistantObjects(); 
    407                          
    408                         // 3) record ids of pivot objects  
    409                         this.PA[0][this.col] = pivots[0]; 
    410                         this.PA[1][this.col] = pivots[1]; 
    411                          
    412                         System.out.println("found pivots with index: " + pivots[0] + ","+ pivots[1]); 
    413                          
    414                         // 4) inter object distances are zero (this.X is initialized with 0 so we just return) 
    415                         if(this.dist(pivots[0], pivots[1], this.col) == 0) { 
    416                                 return; 
    417                         } 
    418                          
    419                         double dxy = this.dist(pivots[0], pivots[1], this.col); 
    420                          
    421                         if(dxy == 0) { 
    422                                 return; 
    423                         } 
    424                          
    425                         // 5) project the objects on the line between the pivots 
    426                         for(int i=0; i < this.O.length; i++) { 
    427                                  
    428                                 double dix = this.dist(i, pivots[0], this.col); 
    429                                 double diy = this.dist(i, pivots[1], this.col); 
    430          
    431                                 this.X[i][this.col] = (dix + dxy - diy) / 2 * Math.sqrt(dxy); 
    432                                  
    433                                 //this.X[i][this.col] = this.project(i, pivots[0], pivots[1]); 
    434                         } 
    435                          
    436                         this.col += 1; 
    437                          
    438                         // 6) recurse 
    439                         this.calculate2(k-1); 
    440                 } 
    441                  
    442                  
    443                 // test funktion, reproduziert ergebnisse aus dem technical report von fastmap 
    444                 public int[] findDistantObjects2() { 
    445                         int[] ret = {0,0}; 
    446                         if(this.col == 0) { 
    447                                 ret = new int[] {0,3}; 
    448                         } 
    449                         if(this.col == 1) { 
    450                                 ret = new int[] {4,1}; 
    451                         } 
    452                         if(this.col == 2) { 
    453                                 ret = new int[] {2,4}; 
    454                         } 
    455                          
    456                         return ret; 
    457                 } 
    458                  
    459                  
     383                        // 2. find farthest object from randomly chosen object 
     384                        int idx1 = this.findFarthest(obj); 
     385                         
     386                        // 3. find farthest object from previously farthest object 
     387                        int idx2 = this.findFarthest(idx1); 
     388 
     389                        return new int[] {idx1, idx2}; 
     390                } 
     391         
    460392                /** 
    461393                 * Calculates the new k-vector values 
     
    488420                                        double tmp = (dix + dxy - diy) / 2 * Math.sqrt(dxy); 
    489421                                         
    490                                         this.X[i][this.col] = tmp; // / 10000;  
    491                                         //this.X[i][this.col] = this.project(i, pivots[0], pivots[1]); 
     422                                        this.X[i][this.col] = tmp; 
    492423                                } 
    493424                                 
     
    495426                        } 
    496427                } 
    497                  
    498428                 
    499429                /** 
     
    506436        } 
    507437 
    508  
     438        /** 
     439         * QuadTree implementation 
     440         *  
     441         * QuadTree gets a list of instances and then recursively split them into 4 childs 
     442         * For this it uses the median of the 2 values x,y 
     443         */ 
    509444        public class QuadTree { 
    510445                 
     
    521456                private ArrayList<QuadTree> l = new ArrayList<QuadTree>(); 
    522457                 
    523  
     458                // level only used for debugging 
    524459                public int level = 0; 
    525460                 
     
    533468                // evtl. statt ArrayList eigene QuadTreePayloadlist 
    534469                private ArrayList<QuadTreePayload<Instance>> payload; 
    535                  
    536470                 
    537471                public QuadTree(QuadTree parent, ArrayList<QuadTreePayload<Instance>> payload) { 
     
    551485                } 
    552486                 
    553                  
    554487                /** 
    555488                 * Returns the payload, used for clustering 
     
    567500                 * density = number of instances / global size (all instances) 
    568501                 *  
    569                  * @return 
     502                 * @return density 
    570503                 */ 
    571504                public double getDensity() { 
     
    586519                 
    587520                /** 
    588                  * Todo: dry, median ist immer dasselbe 
     521                 * Todo: DRY, median ist immer dasselbe 
    589522                 *   
    590523                 * @return median for x 
     
    663596                 
    664597                /** 
    665                  * Calculate median values of payload for x, y and split into sectors 
     598                 * Calculate median values of payload for x, y and split into 4 sectors 
    666599                 *  
    667600                 * @return Array of QuadTree nodes (4 childs) 
     
    673606                        double medy = this.getMedianForY(); 
    674607                         
    675                          
     608                        // Payload lists for each child 
    676609                        ArrayList<QuadTreePayload<Instance>> nw = new ArrayList<QuadTreePayload<Instance>>(); 
    677610                        ArrayList<QuadTreePayload<Instance>> sw = new ArrayList<QuadTreePayload<Instance>>(); 
     
    710643                        // if we assign one child a payload equal to our own (see problem above) 
    711644                        // we throw an exceptions which stops the recursion on this node 
     645                        // second error is minimum number of instances 
     646                        //Console.traceln(Level.INFO, String.format("NW: "+ nw.size() + " SW: " + sw.size() + " NE: " + ne.size() + " SE: " + se.size())); 
    712647                        if(nw.equals(this.payload)) { 
    713648                                throw new Exception("payload equal"); 
    714649                        } 
    715650                        if(sw.equals(this.payload)) { 
    716                                 throw new Exception("ayload equal"); 
     651                                throw new Exception("payload equal"); 
    717652                        } 
    718653                        if(ne.equals(this.payload)) { 
     
    737672                        this.child_se = new QuadTree(this, se); 
    738673                        this.child_se.setSize(new double[] {medx, this.x[1]}, new double[] {this.y[0], medy}); 
    739                         this.child_se.level = this.level + 1; 
     674                        this.child_se.level = this.level + 1;    
    740675                         
    741676                        this.payload = null; 
     
    844779                } 
    845780                 
    846                  
    847                  
    848781                /** 
    849782                 * Perform Pruning and clustering of the quadtree 
     
    851784                 * 1) get list of leaf quadrants 
    852785                 * 2) sort by their density 
    853                  * 3) merge similar densities to new leaf quadrant 
    854                  * @param q QuadTree 
     786                 * 3) set stop_rule to 0.5 * highest Density in the list 
     787                 * 4) merge all nodes with a density > stop_rule to the new cluster and remove all from list 
     788                 * 5) repeat 
     789                 *  
     790                 * @param q List of QuadTree (children only) 
    855791                 */ 
    856792                public void gridClustering(ArrayList<QuadTree> list) { 
     
    885821                    //System.out.println("removing "+biggest.getDensity() + " from list"); 
    886822                     
    887                         // while items in list 
     823                        // check the items for their density 
    888824                    for(int i=list.size()-1; i >= 0; i--) { 
    889825                        current = list.get(i); 
    890826                         
    891827                                // 2. find neighbours with correct density 
    892                         // if density > 0.5 * DELTA and is_neighbour add to cluster 
     828                        // if density > stop_rule and is_neighbour add to cluster and remove from list 
    893829                        if(current.getDensity() > stop_rule && !current.equals(biggest) && current.isNeighbour(biggest)) { 
    894830                                //System.out.println("adding " + current.getDensity() + " to cluster"); 
     
    923859                 
    924860                /** 
    925                  *  
    926                  * @param q 
    927                  * @return 
     861                 * Helper Method to get a sortet list (by density) for all 
     862                 * children 
     863                 *  
     864                 * @param q QuadTree 
     865                 * @return Sorted ArrayList of quadtrees 
    928866                 */ 
    929867                public ArrayList<QuadTree> getList(QuadTree q) { 
Note: See TracChangeset for help on using the changeset viewer.