Ignore:
Timestamp:
09/24/15 10:59:05 (9 years ago)
Author:
sherbold
Message:
  • formatted code and added copyrights
File:
1 edited

Legend:

Unmodified
Added
Removed
  • trunk/CrossPare/src/de/ugoe/cs/cpdp/training/WekaBaseTraining.java

    r25 r41  
     1// Copyright 2015 Georg-August-Universität Göttingen, Germany 
     2// 
     3//   Licensed under the Apache License, Version 2.0 (the "License"); 
     4//   you may not use this file except in compliance with the License. 
     5//   You may obtain a copy of the License at 
     6// 
     7//       http://www.apache.org/licenses/LICENSE-2.0 
     8// 
     9//   Unless required by applicable law or agreed to in writing, software 
     10//   distributed under the License is distributed on an "AS IS" BASIS, 
     11//   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
     12//   See the License for the specific language governing permissions and 
     13//   limitations under the License. 
     14 
    115package de.ugoe.cs.cpdp.training; 
    216 
     
    1529 * Allows specification of the Weka classifier and its params in the XML experiment configuration. 
    1630 *  
    17  * Important conventions of the XML format:  
    18  * Cross Validation params always come last and are prepended with -CVPARAM 
    19  * Example: <trainer name="WekaTraining" param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/> 
     31 * Important conventions of the XML format: Cross Validation params always come last and are 
     32 * prepended with -CVPARAM Example: <trainer name="WekaTraining" 
     33 * param="RandomForestLocal weka.classifiers.trees.RandomForest -CVPARAM I 5 25 5"/> 
    2034 */ 
    2135public abstract class WekaBaseTraining implements IWekaCompatibleTrainer { 
    22          
    23         protected Classifier classifier = null; 
    24         protected String classifierClassName; 
    25         protected String classifierName; 
    26         protected String[] classifierParams; 
    27          
    28         @Override 
    29         public void setParameter(String parameters) { 
    30                 String[] params = parameters.split(" "); 
    3136 
    32                 // first part of the params is the classifierName (e.g. SMORBF) 
    33                 classifierName = params[0]; 
    34                  
    35                 // the following parameters can be copied from weka! 
    36                  
    37                 // second param is classifierClassName (e.g. weka.classifiers.functions.SMO) 
    38                 classifierClassName = params[1]; 
    39          
    40                 // rest are params to the specified classifier (e.g. -K weka.classifiers.functions.supportVector.RBFKernel) 
    41                 classifierParams = Arrays.copyOfRange(params, 2, params.length); 
    42                  
    43                 classifier = setupClassifier(); 
    44         } 
     37    protected Classifier classifier = null; 
     38    protected String classifierClassName; 
     39    protected String classifierName; 
     40    protected String[] classifierParams; 
    4541 
    46         @Override 
    47         public Classifier getClassifier() { 
    48                 return classifier; 
    49         } 
     42    @Override 
     43    public void setParameter(String parameters) { 
     44        String[] params = parameters.split(" "); 
    5045 
    51         public Classifier setupClassifier() { 
    52                 Classifier cl = null; 
    53                 try{ 
    54                         @SuppressWarnings("rawtypes") 
    55                         Class c = Class.forName(classifierClassName); 
    56                         Classifier obj = (Classifier) c.newInstance(); 
    57                          
    58                         // Filter out -CVPARAM, these are special because they do not belong to the Weka classifier class as parameters 
    59                         String[] param = Arrays.copyOf(classifierParams, classifierParams.length); 
    60                         String[] cvparam = {}; 
    61                         boolean cv = false; 
    62                         for ( int i=0; i < classifierParams.length; i++ ) { 
    63                                 if(classifierParams[i].equals("-CVPARAM")) { 
    64                                         // rest of array are cvparam 
    65                                         cvparam = Arrays.copyOfRange(classifierParams, i+1, classifierParams.length); 
    66                                          
    67                                         // before this we have normal params 
    68                                         param = Arrays.copyOfRange(classifierParams, 0, i); 
    69                                          
    70                                         cv = true; 
    71                                         break; 
    72                                 } 
    73                         } 
    74                          
    75                         // set classifier params 
    76                         ((OptionHandler)obj).setOptions(param); 
    77                         cl = obj; 
    78                          
    79                         // we have cross val params 
    80                         // cant check on cvparam.length here, it may not be initialized                  
    81                         if(cv) { 
    82                                 final CVParameterSelection ps = new CVParameterSelection(); 
    83                                 ps.setClassifier(obj); 
    84                                 ps.setNumFolds(5); 
    85                                 //ps.addCVParameter("I 5 25 5"); 
    86                                 for( int i=1 ; i<cvparam.length/4 ; i++ ) { 
    87                                         ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4*i)).toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", "")); 
    88                                 } 
    89                                  
    90                                 cl = ps; 
    91                         } 
     46        // first part of the params is the classifierName (e.g. SMORBF) 
     47        classifierName = params[0]; 
    9248 
    93                 }catch(ClassNotFoundException e) { 
    94                         Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString())); 
    95                         e.printStackTrace(); 
    96                 } catch (InstantiationException e) { 
    97                         Console.traceln(Level.WARNING, String.format("Instantiation Exception: %s", e.toString())); 
    98                         e.printStackTrace(); 
    99                 } catch (IllegalAccessException e) { 
    100                         Console.traceln(Level.WARNING, String.format("Illegal Access Exception: %s", e.toString())); 
    101                         e.printStackTrace(); 
    102                 } catch (Exception e) { 
    103                         Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString())); 
    104                         e.printStackTrace(); 
    105                 } 
    106                  
    107                 return cl; 
    108         } 
     49        // the following parameters can be copied from weka! 
    10950 
    110         @Override 
    111         public String getName() { 
    112                 return classifierName; 
    113         } 
    114          
     51        // second param is classifierClassName (e.g. weka.classifiers.functions.SMO) 
     52        classifierClassName = params[1]; 
     53 
     54        // rest are params to the specified classifier (e.g. -K 
     55        // weka.classifiers.functions.supportVector.RBFKernel) 
     56        classifierParams = Arrays.copyOfRange(params, 2, params.length); 
     57 
     58        classifier = setupClassifier(); 
     59    } 
     60 
     61    @Override 
     62    public Classifier getClassifier() { 
     63        return classifier; 
     64    } 
     65 
     66    public Classifier setupClassifier() { 
     67        Classifier cl = null; 
     68        try { 
     69            @SuppressWarnings("rawtypes") 
     70            Class c = Class.forName(classifierClassName); 
     71            Classifier obj = (Classifier) c.newInstance(); 
     72 
     73            // Filter out -CVPARAM, these are special because they do not belong to the Weka 
     74            // classifier class as parameters 
     75            String[] param = Arrays.copyOf(classifierParams, classifierParams.length); 
     76            String[] cvparam = { }; 
     77            boolean cv = false; 
     78            for (int i = 0; i < classifierParams.length; i++) { 
     79                if (classifierParams[i].equals("-CVPARAM")) { 
     80                    // rest of array are cvparam 
     81                    cvparam = Arrays.copyOfRange(classifierParams, i + 1, classifierParams.length); 
     82 
     83                    // before this we have normal params 
     84                    param = Arrays.copyOfRange(classifierParams, 0, i); 
     85 
     86                    cv = true; 
     87                    break; 
     88                } 
     89            } 
     90 
     91            // set classifier params 
     92            ((OptionHandler) obj).setOptions(param); 
     93            cl = obj; 
     94 
     95            // we have cross val params 
     96            // cant check on cvparam.length here, it may not be initialized 
     97            if (cv) { 
     98                final CVParameterSelection ps = new CVParameterSelection(); 
     99                ps.setClassifier(obj); 
     100                ps.setNumFolds(5); 
     101                // ps.addCVParameter("I 5 25 5"); 
     102                for (int i = 1; i < cvparam.length / 4; i++) { 
     103                    ps.addCVParameter(Arrays.asList(Arrays.copyOfRange(cvparam, 0, 4 * i)) 
     104                        .toString().replaceAll(", ", " ").replaceAll("^\\[|\\]$", "")); 
     105                } 
     106 
     107                cl = ps; 
     108            } 
     109 
     110        } 
     111        catch (ClassNotFoundException e) { 
     112            Console.traceln(Level.WARNING, String.format("class not found: %s", e.toString())); 
     113            e.printStackTrace(); 
     114        } 
     115        catch (InstantiationException e) { 
     116            Console.traceln(Level.WARNING, 
     117                            String.format("Instantiation Exception: %s", e.toString())); 
     118            e.printStackTrace(); 
     119        } 
     120        catch (IllegalAccessException e) { 
     121            Console.traceln(Level.WARNING, 
     122                            String.format("Illegal Access Exception: %s", e.toString())); 
     123            e.printStackTrace(); 
     124        } 
     125        catch (Exception e) { 
     126            Console.traceln(Level.WARNING, String.format("Exception: %s", e.toString())); 
     127            e.printStackTrace(); 
     128        } 
     129 
     130        return cl; 
     131    } 
     132 
     133    @Override 
     134    public String getName() { 
     135        return classifierName; 
     136    } 
     137 
    115138} 
Note: See TracChangeset for help on using the changeset viewer.