
Use decision rules to split other data


I am looking for an elegant way to take the decision rules learned on one dataset (for example, a training set) and use them to split another dataset (e.g. test data) according to those rules.

Look at this example:

# Load PimaIndiansDiabetes dataset from mlbench package
library("mlbench")
data("PimaIndiansDiabetes")
## Split in training and test (2/3 - 1/3)
idtrain <- c(sample(1:768,512))
PimaTrain <-PimaIndiansDiabetes[idtrain,]
Pimatest <-PimaIndiansDiabetes[-idtrain,]

m1 <- RWeka::J48(as.factor(as.character(PimaTrain$diabetes)) ~ .,
                 data = PimaTrain[,-c(9)],
                 control = RWeka::Weka_control(M = 10, C= 0.25))

This gives the following output:

> m1
J48 pruned tree
------------------

glucose <= 154
|   age <= 28
|   |   glucose <= 118: neg (157.0/11.0)
|   |   glucose > 118
|   |   |   pressure <= 52: pos (10.0/3.0)
|   |   |   pressure > 52: neg (54.0/12.0)
|   age > 28
|   |   glucose <= 103: neg (54.0/10.0)
|   |   glucose > 103
|   |   |   mass <= 41.3: neg (129.0/55.0)
|   |   |   mass > 41.3: pos (12.0/1.0)
glucose > 154: pos (96.0/19.0)

Number of Leaves  :     7

Size of the tree :  13

Based on these rules there are 7 groups (leaves). What I am looking for is a way to apply these rules to the test data Pimatest, without re-training a decision tree, so that every data point is assigned to one of the 7 groups, recorded in a new variable group.

The output would look like this:

head(Pimatest)
   pregnant glucose pressure triceps insulin mass pedigree age diabetes group
3         8     183       64       0       0 23.3    0.672  32      pos     7
4         1      89       66      23      94 28.1    0.167  21      neg     1
6         5     116       74       0       0 25.6    0.201  30      neg     5
7         3      78       50      32      88 31.0    0.248  26      pos     1
8        10     115        0       0       0 35.3    0.134  29      neg     5
11        4     110       92       0       0 37.6    0.191  30      neg     5

I currently have a working solution, but it is badly coded, which is why I am looking for an elegant solution to this problem.


Solution

  • As I understand it, you want to be able to tie each point to the set of rules that classify that point. You can get there by converting the J48 tree to a party tree and using tools from the partykit package.

    Because you did not set the seed for the random number generator, we cannot reproduce exactly the same training/test split that you got. I will set the seed to make my example reproducible, but even though I use your code, my tree will differ slightly from yours.

    Reproducible example (mostly your code)

    library(RWeka)
    library("mlbench")
    data("PimaIndiansDiabetes")
    
    ## Split in training and test (2/3 - 1/3)
    set.seed(1234)
    idtrain <- c(sample(1:768,512))
    PimaTrain <-PimaIndiansDiabetes[idtrain,]
    Pimatest <-PimaIndiansDiabetes[-idtrain,]
    
    m1 <- RWeka::J48(as.factor(as.character(PimaTrain$diabetes)) ~ .,
                     data = PimaTrain[,-c(9)],
                     control = RWeka::Weka_control(M = 10, C= 0.25))
    m1
    J48 pruned tree
    ------------------
    glucose <= 122
    |   mass <= 26.8: neg (85.0/1.0)
    |   mass > 26.8
    |   |   pregnant <= 4: neg (137.0/19.0)
    |   |   pregnant > 4
    |   |   |   glucose <= 106: neg (44.0/10.0)
    |   |   |   glucose > 106: pos (24.0/6.0)
    glucose > 122
    |   glucose <= 157
    |   |   age <= 31
    |   |   |   age <= 24: neg (30.0/5.0)
    |   |   |   age > 24
    |   |   |   |   pressure <= 72: pos (16.0/5.0)
    |   |   |   |   pressure > 72: neg (22.0/5.0)
    |   |   age > 31: pos (78.0/27.0)
    |   glucose > 157: pos (76.0/13.0)
    
    Number of Leaves  :     9
    Size of the tree :      17
    

    My tree has 9 leaves instead of your 7. This is because different instances were chosen for the training set. Now we are ready to get the rules.

    library(partykit)
    Pm1 = as.party(m1)
    Pm1
    Fitted party:
    [1] root
    |   [2] glucose <= 122
    |   |   [3] mass <= 26.8: neg (n = 85, err = 1.2%)
    |   |   [4] mass > 26.8
    |   |   |   [5] pregnant <= 4: neg (n = 137, err = 13.9%)
    |   |   |   [6] pregnant > 4
    |   |   |   |   [7] glucose <= 106: neg (n = 44, err = 22.7%)
    |   |   |   |   [8] glucose > 106: pos (n = 24, err = 25.0%)
    |   [9] glucose > 122
    |   |   [10] glucose <= 157
    |   |   |   [11] age <= 31
    |   |   |   |   [12] age <= 24: neg (n = 30, err = 16.7%)
    |   |   |   |   [13] age > 24
    |   |   |   |   |   [14] pressure <= 72: pos (n = 16, err = 31.2%)
    |   |   |   |   |   [15] pressure > 72: neg (n = 22, err = 22.7%)
    |   |   |   [16] age > 31: pos (n = 78, err = 34.6%)
    |   |   [17] glucose > 157: pos (n = 76, err = 17.1%)
    
    Number of inner nodes:    8
    Number of terminal nodes: 9
    

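    This is the same tree as before, but it has the advantage that the nodes are numbered. We can also get the rules written out for each leaf. Note that .list.rules.party is not exported by partykit, which is why the call below uses ::: to reach it.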

    Pm1_rules = partykit:::.list.rules.party(Pm1)
    Pm1_rules
                                                                           3 
                                             "glucose <= 122 & mass <= 26.8" 
                                                                           5 
                              "glucose <= 122 & mass > 26.8 & pregnant <= 4" 
                                                                           7 
              "glucose <= 122 & mass > 26.8 & pregnant > 4 & glucose <= 106" 
                                                                           8 
               "glucose <= 122 & mass > 26.8 & pregnant > 4 & glucose > 106" 
                                                                          12 
                    "glucose > 122 & glucose <= 157 & age <= 31 & age <= 24" 
                                                                          14 
    "glucose > 122 & glucose <= 157 & age <= 31 & age > 24 & pressure <= 72" 
                                                                          15 
     "glucose > 122 & glucose <= 157 & age <= 31 & age > 24 & pressure > 72" 
                                                                          16 
                                 "glucose > 122 & glucose <= 157 & age > 31" 
                                                                          17 
                                             "glucose > 122 & glucose > 157" 
    

    The decisions are written out as rules, and the name of each rule is the number of its leaf node. To get the rule that applies to a test point, you just need to know which leaf node it ends up in. The predict method for party objects gives you exactly that when called with type="node".

    TestPred = predict(Pm1, newdata=Pimatest, type="node")
    TestPred
      3   4   5   6   9  12  17  20  22  27  28  29  31  32  33  35  36  38  41  43 
     17   5  16   3  17  17   5   5   7  16   3  16   8  17   3   8   3   7  17   3 
     46  48  50  56  57  60  62  64  65  66  68  70  72  75  76  79  84  95  96  97 
     17   5   3   3  17   5  16  12   8   7   5  15  14   5   3  14   3  12  16   5 
    ...
    

    I truncated the output because it was too long. The first test point, for example, went to node 17. We just need to use that to index into the rules, but a little care is needed: the 17 returned by predict is a number, whereas the rules are named by strings, so we need as.character to convert it. Indexing with the number itself would select by position instead of by name.

    Pm1_rules[as.character(TestPred[1])]
                                 17 
    "glucose > 122 & glucose > 157" 
    

    We confirm:

    Pimatest[1,]
      pregnant glucose pressure triceps insulin mass pedigree age diabetes
    3        8     183       64       0       0 23.3    0.672  32      pos
    

    So yes: glucose is 183, which satisfies both glucose > 122 and glucose > 157.
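To make the indexing caveat concrete, here is a minimal base R sketch. The two rule strings are copied from the output above; the variable names `rules`, `node`, `by_position` and `by_name` are only illustrative and are not part of the original answer.

```r
# Two of the rules from above, named by their leaf node numbers.
rules <- c("3"  = "glucose <= 122 & mass <= 26.8",
           "17" = "glucose > 122 & glucose > 157")

node <- 17L  # a node id as a number, like the values predict() returns

# Numeric index: asks for the 17th element, which does not exist.
by_position <- rules[node]

# Character index: looks up the element whose name is "17".
by_name <- rules[as.character(node)]

print(by_position)
print(by_name)
```

The numeric lookup silently returns a missing value rather than raising an error, which is why the conversion is easy to forget.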

    You can get the rules for the other test points in the same way.
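Since the question asks for a group column on the test data, here is one possible sketch of that last step in base R only. It assumes the rule strings are valid R expressions over the column names, which holds for the numeric splits shown above. The rules are copied from the Pm1_rules output, the three rows are test points 3, 4 and 6 from the head(Pimatest) listing in the question, and `assign_leaf` is a hypothetical helper, not a partykit function.

```r
# Rule strings copied from the Pm1_rules output, named by leaf node number.
rules <- c(
  "3"  = "glucose <= 122 & mass <= 26.8",
  "5"  = "glucose <= 122 & mass > 26.8 & pregnant <= 4",
  "7"  = "glucose <= 122 & mass > 26.8 & pregnant > 4 & glucose <= 106",
  "8"  = "glucose <= 122 & mass > 26.8 & pregnant > 4 & glucose > 106",
  "12" = "glucose > 122 & glucose <= 157 & age <= 31 & age <= 24",
  "14" = "glucose > 122 & glucose <= 157 & age <= 31 & age > 24 & pressure <= 72",
  "15" = "glucose > 122 & glucose <= 157 & age <= 31 & age > 24 & pressure > 72",
  "16" = "glucose > 122 & glucose <= 157 & age > 31",
  "17" = "glucose > 122 & glucose > 157"
)

# Test points 3, 4 and 6 from the question (only the columns the rules use).
newdata <- data.frame(
  pregnant = c(8, 1, 5),
  glucose  = c(183, 89, 116),
  pressure = c(64, 66, 74),
  mass     = c(23.3, 28.1, 25.6),
  age      = c(32, 21, 30)
)

# Hypothetical helper: evaluate every rule on every row and return the
# name of the single rule that each row satisfies.
assign_leaf <- function(rules, data) {
  hits <- sapply(rules, function(r) eval(parse(text = r), envir = data))
  hits <- matrix(hits, nrow = nrow(data), dimnames = list(NULL, names(rules)))
  # The leaves of a tree partition the data, so exactly one rule must match.
  stopifnot(all(rowSums(hits) == 1))
  names(rules)[apply(hits, 1, which)]
}

newdata$node  <- assign_leaf(rules, newdata)        # leaf node number
newdata$group <- match(newdata$node, names(rules))  # leaves renumbered 1..9
print(newdata)
```

For these three rows the helper lands on nodes 17, 5 and 3, the same nodes that predict reported for test points 3, 4 and 6 above. With the fitted objects from this answer you would not need to re-evaluate the rules at all: something like `Pimatest$group <- match(as.character(TestPred), names(Pm1_rules))` should give the same 1 to 9 numbering directly from the predicted nodes.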