
How to get the attributes involved in the classification of an ADTree


My problem is the following:

I'm performing classification using Weka's ADTree. I build a classifier on a dataset that has over 1700 attributes. The resulting ADTree uses only a very small subset of them to classify the instances (about 10 attributes are used).

My question is: since computing the attributes for my instances is time-consuming, can I retrieve the identifiers of the attributes used by the ADTree? I aim to compute only the relevant attributes and leave the others at a default value, in order to avoid long and useless computation.

Thanks in advance.


Solution

  • I finally found a solution. I had to extend ADTree to achieve my goal:

    import java.util.Collections;
    import java.util.Enumeration;
    import java.util.LinkedList;
    import java.util.List;
    
    import weka.classifiers.trees.ADTree;
    import weka.classifiers.trees.adtree.PredictionNode;
    import weka.classifiers.trees.adtree.Splitter;
    import weka.core.Instances;
    
    public class ExtendedADTree extends ADTree {
    
        private static final long serialVersionUID = 1L;
    
        /**
         * @param dataset
         *        The dataset used to build the tree; needed to resolve attribute labels.
         *
         * @return The label of the attribute tested by each splitter node of the tree, in
         *         breadth-first order. An attribute tested by several nodes appears once per node.
         */
        @SuppressWarnings("unchecked") // needed when children() returns a raw Enumeration
        public List<String> getPredictionNodeLabels(final Instances dataset) {
            final List<String> result = new LinkedList<>();
            if (this.m_root == null) {
                // The classifier has not been built yet: there is no node to explore.
                return result;
            }
    
            // Initialize the queue of splitter nodes to explore with the children of the root.
            final LinkedList<Splitter> nodesToExplore = new LinkedList<>();
            final Enumeration<Splitter> rootChildren = this.m_root.children();
            nodesToExplore.addAll(Collections.list(rootChildren));
    
            while (!nodesToExplore.isEmpty()) {
                // Take the next splitter node from the front of the queue.
                final Splitter currentNode = nodesToExplore.removeFirst();
    
                // Add the label of the attribute tested by this splitter node to the result.
                result.add(currentNode.attributeString(dataset));
    
                // Queue the splitter nodes found under each branch of the current node.
                for (int branch = 0; branch < currentNode.getNumOfBranches(); branch++) {
                    final PredictionNode child = currentNode.getChildForBranch(branch);
                    final Enumeration<Splitter> childChildren = child.children();
                    nodesToExplore.addAll(Collections.list(childChildren));
                }
            }
    
            return result;
        }
    }
    

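    This method returns the labels of the attributes tested by the tree's decision (splitter) nodes. The list may contain the same attribute more than once if several nodes test it.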
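
Since the method above adds one label per decision node, the same attribute can show up several times in the result. To compute only the relevant attributes, the labels can be reduced to a set of distinct attribute positions. Below is a minimal sketch of that step using plain Java collections only. The class `RelevantAttributes` and the method `relevantIndices` are hypothetical names, not part of Weka, and the sketch assumes that each label is exactly an attribute name as it appears in the dataset header.

```java
import java.util.ArrayList;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

/** Hypothetical helper (not part of Weka): maps node labels to distinct attribute indices. */
final class RelevantAttributes {

    private RelevantAttributes() {
        // Utility class, no instances.
    }

    /**
     * @param labels
     *        Labels returned by the tree traversal; duplicates are allowed.
     * @param attributeNames
     *        Attribute names of the dataset, in header order (assumed to be unique).
     *
     * @return The distinct indices of the attributes named in labels, in first-seen order.
     */
    static List<Integer> relevantIndices(final List<String> labels,
                                         final List<String> attributeNames) {
        // LinkedHashSet drops duplicates while keeping insertion order.
        final Set<Integer> indices = new LinkedHashSet<>();
        for (final String label : labels) {
            final int index = attributeNames.indexOf(label);
            if (index < 0) {
                // Assumption violated: the label is not a plain attribute name.
                throw new IllegalArgumentException("Unknown attribute: " + label);
            }
            indices.add(index);
        }
        return new ArrayList<>(indices);
    }
}
```

With a header of `length, width, color, weight` and labels `color, length, color`, this yields the indices 2 and 0. When working with a Weka `Instances` object directly, `dataset.attribute(name).index()` should give the same position without keeping a separate list of names.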