-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathggparty-graphic-partying.Rmd
966 lines (837 loc) · 41.2 KB
/
ggparty-graphic-partying.Rmd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
---
title: "ggparty: Graphic Partying"
author: "Martin Borkovec"
date: "`r Sys.Date()`"
output:
  html_document:
    theme: flatly
    toc: true
    toc_float:
      collapsed: false
      smooth_scroll: false
    toc_depth: 2
vignette: >
  %\VignetteIndexEntry{ggparty}
  %\VignetteEngine{knitr::rmarkdown}
  %\VignetteEncoding{UTF-8}
---
<style>
body {
text-align: justify}
</style>
```{r setup, include = FALSE}
# Chunk defaults for the whole document: collapse source and output into one
# block, prefix printed output with "#>" and use a figure width of 7.
knitr::opts_chunk$set(collapse = TRUE, comment = "#>", fig.width = 7)
```
**ggparty** aims to extend **ggplot2** functionality to the **partykit** package. It provides the necessary tools
to create clearly structured and highly customizable visualizations for tree-objects of the class `'party'`.
# ggparty
Loading the **ggparty** package will also load **partykit** and **ggplot2** and thereby provide all necessary functions.
```{r}
library(ggparty)
```
## Motivating Example
The following plot can be created fairly easily with **ggparty**. All it takes is
an object of class `party`, some basic knowledge of **ggplot2** and comprehension of
the topics covered in this vignette.
```{r, fig.asp = 1, eval = T, echo = FALSE}
data("TeachingRatings", package = "AER")
# restrict to the rows where credits == "more"
tr <- subset(TeachingRatings, credits == "more")
# tree of linear models eval ~ beauty, partitioned over the covariates after "|"
tr_tree <- lmtree(eval ~ beauty | minority + age + gender + division + native +
                    tenure, data = tr, weights = students, caseweights = FALSE)

# create dataframe with densities
# age value per node id above which the density is labelled "right"
age_cutoffs <- c("2" = 50, "5" = 40)
# estimate each node's age density only once and bind all nodes in a single
# step (instead of calling density() twice and growing dens_df inside a loop)
dens_df <- do.call(rbind, lapply(c(2, 5), function(id) {
  age_dens <- density(tr_tree[id]$data$age)
  breaks <- rep("left", length(age_dens$x))
  breaks[age_dens$x > age_cutoffs[[as.character(id)]]] <- "right"
  data.frame(x_dens = age_dens$x, y_dens = age_dens$y, id = id,
             breaks = breaks)
}))

# get the party started
ggparty(tr_tree, terminal_space = 0.4,
        # manual positions for nodes 1, 2, 5 and 7
        layout = data.frame(id = c(1, 2, 5, 7),
                            x = c(0.35, 0.15, 0.7, 0.8),
                            y = c(0.95, 0.6, 0.8, 0.55))) +
  geom_edge(aes(col = factor(birth_order)),
            size = 1.2,
            alpha = 1,
            ids = -1) +
  # age densities for nodes 2 and 5, filled by side of the cut-off
  geom_node_plot(ids = c(2, 5),
                 gglist = list(
                   geom_line(data = dens_df,
                             aes(x = x_dens,
                                 y = y_dens),
                             show.legend = FALSE,
                             alpha = 0.8),
                   geom_ribbon(data = dens_df,
                               aes(x = x_dens,
                                   ymin = 0,
                                   ymax = y_dens,
                                   fill = breaks),
                               show.legend = FALSE,
                               alpha = 0.8),
                   xlab("age"),
                   theme_bw(),
                   theme(axis.title.y = element_blank())),
                 size = 1.5,
                 height = 0.5
  ) +
  # bar chart of gender for node 1
  geom_node_plot(ids = 1,
                 gglist = list(geom_bar(aes(x = gender, fill = gender),
                                        show.legend = FALSE,
                                        alpha = .8),
                               theme_bw(),
                               theme(axis.title.y = element_blank())),
                 size = 1.5,
                 height = 0.5
  ) +
  # bar chart of division for node 7
  geom_node_plot(ids = 7,
                 gglist = list(geom_bar(aes(x = division, fill = division),
                                        show.legend = FALSE,
                                        alpha = .8),
                               theme_bw(),
                               theme(axis.title.y = element_blank())),
                 size = 1.5,
                 height = 0.5
  ) +
  # scatter plots with prediction line for the terminal nodes
  geom_node_plot(gglist = list(geom_point(aes(x = beauty,
                                              y = eval,
                                              col = tenure,
                                              shape = minority),
                                          alpha = 0.8),
                               theme_bw(base_size = 10),
                               scale_color_discrete(h.start = 100)),
                 scales = "fixed",
                 ids = "terminal",
                 # TRUE/FALSE instead of the reassignable shorthands T/F
                 shared_axis_labels = TRUE,
                 shared_legend = TRUE,
                 predict = "beauty",
                 predict_gpar = list(col = "blue",
                                     size = 1.1)) +
  theme(legend.position = "none")
```
The code used to create this plot can be found at the end of this document.
But first things first.
Let's recreate a simple example already used in the [partykit vignette](https://cran.r-project.org/web/packages/partykit/vignettes/partykit.pdf). If you
are not familiar with the [**partykit**](https://cran.r-project.org/web/packages/partykit/index.html) package, you should definitely check it out before you
work with this package.
```{r}
data("WeatherPlay", package = "partykit")

# Three splits: on column 1 (index 1:3), on column 3 (break at 75) and on
# column 4 (index 1:2).
sp_o <- partysplit(1L, index = 1:3)
sp_h <- partysplit(3L, breaks = 75)
sp_w <- partysplit(4L, index = 1:2)

# Root (id 1) splits with sp_o into three kids: an inner node using sp_h,
# a terminal node, and an inner node using sp_w.
pn <- partynode(
  1L,
  split = sp_o,
  kids = list(
    partynode(2L, split = sp_h, kids = list(
      partynode(3L, info = "yes"),
      partynode(4L, info = "no")
    )),
    partynode(5L, info = "yes"),
    partynode(6L, split = sp_w, kids = list(
      partynode(7L, info = "yes"),
      partynode(8L, info = "no")
    ))
  )
)

# Combine the node structure with the data into a 'party' object.
py <- party(pn, WeatherPlay)
```
## `ggparty()`
The `ggparty()` function takes a tree of class `party` and allows us to plot it with
the help of the **ggplot2** package. To make this possible, the `'party'` object first
needs to be transformed into a `'data.frame'` and be passed to a `ggplot()` call. This is
exactly what happens when we run `ggparty()`.
```{r, results = "asis"}
# build the plot object once and inspect it
py_gg <- ggparty(py)
# check that ggparty() produced a ggplot object
is.ggplot(py_gg)
# show the first 16 columns of the plot data as a pandoc table
pander::pandoc.table(py_gg$data[, 1:16])
```
## Plot Data
The first 16 columns of the `'data.frame'` passed by `ggparty()` to `ggplot()` contain these
values:
* **id**... ID of the node
* **x**... X coordinate of the node
* **y**... Y coordinate of the node
* **parent**... ID of node's parent
* **birth_order**... Position relative to parent. Goes from left to right.
* **breaks_label**... String containing the corresponding split break of the parent's split variable.
* **info**... String containing the info of the node
* **info_list**... List containing the info of the node if it was a list
* **splitvar**... String containing the name of the Variable to split with. (only inner nodes)
* **level**... At which level to draw the node. (0 = root)
* **kids**... Number of node's kids
* **nodesize**... Number of rows in node's data.
* **p.value**... P value of model if present
* **horizontal**... Logical - specifies whether the tree is to be drawn horizontally or vertically. Identical for all nodes.
* **x_parent**... X coordinate of the node's parent
* **y_parent**... Y coordinate of the node's parent
The remaining columns contain lists of the node's `data` and we will need `geom_node_plot()` to work with them.
# Plotting a Tree
Every **ggparty** plot starts with a call to the eponymous `ggparty()` function which requires an object of class `'party'`. To draw a tree we will need to add several of these components:
## Basic Building Blocks
* **geom_edge()** draws the edges between the nodes
* **geom_edge_label()** labels the edges with the corresponding split breaks
* **geom_node_label()** labels the nodes with the split variable, node info or anything
else. The shorthand versions of this geom **geom_node_splitvar()** and
**geom_node_info()** have the correct defaults to write the split variables in
the inner nodes resp. the info in the terminal nodes.
* **geom_node_plot()** creates a custom ggplot at the location of the node
In most cases we will probably want to draw at least edges, edge labels and node labels, so we will have to call the respective functions. The default mappings of `geom_edge()` and `geom_edge_label()` ensure that lines between the related nodes are drawn and the corresponding split breaks are plotted at their centers.
Since the text we want to print on the nodes differs depending on the kind of node, we will call geom_node_label twice. Once for the inner nodes, to plot the split variables and once for the terminal nodes to plot the info elements of the tree, which in this case contain the play decision.
```{r Weatherplay, fig.width = 7}
ggparty(py) +
  geom_edge() +
  geom_edge_label() +
  # inner nodes: label with the split variable
  # (identical to geom_node_splitvar())
  geom_node_label(aes(label = splitvar), ids = "inner") +
  # terminal nodes: label with the node info
  # (identical to geom_node_info())
  geom_node_label(aes(label = info), ids = "terminal")
```
Instead of adding `geom_node_label()` we can also add the convenience versions `geom_node_splitvar()` and `geom_node_info()` which contain the correct defaults
to plot the split variables in the inner nodes and the info in the terminal nodes.
Thanks to the ggplot2 mechanics we can now map different aspects of our plot to properties of the nodes. Whether that's the best choice in this case is a different question.
```{r, fig.width = 7}
ggparty(py) +
  geom_edge() +
  geom_edge_label() +
  # colour follows the node's level and text size its nodesize; the same
  # mapping is used for the inner (splitvar) and terminal (info) labels
  geom_node_splitvar(aes(col = factor(level), size = nodesize)) +
  geom_node_info(aes(col = factor(level), size = nodesize))
```
We can create a horizontal tree simply by setting `horizontal` in `ggparty()` to `TRUE`.
```{r, fig.width = 7, eval = T}
# horizontal = TRUE draws the tree horizontally
ggparty(py, horizontal = TRUE) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_info()
```
## Additional Data {#Additional_Data}
This section is about extracting additional elements from the `'party'` object or adding new data. If you just want to know how to make pretty plots, feel free to skip forward to the next section.
If the default amount of elements extracted from the `'party'` object is not enough
for our purposes, there is a way to add more.
Setting the argument `add_vars` of the `ggparty()` call we can specify what to extract
and how to store it (affecting how we can use it later on). Let's say we want to
add for each node the information whether the split break
is closed on the right.
We can do this the following way:
```{r, eval = T}
# add a plot-data column "right", extracted from each node via the given path
gg <- ggparty(
  py,
  add_vars = list(right = "$node$split$right")
)
gg$data$right
```
As we can see we need to pass a named `'list'` to `add_vars`. The names of the elements of the list
will become the names of the columns in the plot data and the elements
of the list need to be either a `'character'` string specifying how to extract the
desired element from each node (as seen above) or a function that will be applied consecutively to each node and each row of the
plot data. If we want to simply add something
to the plot data, so that it can be accessed by base level geoms (geoms making
up the tree) it has to be
of `length` one like in the example above. The same result can of course be achieved using a `'function'`:
```{r, eval = T}
# same column "right" as before, this time extracted by a function
gg <- ggparty(
  py,
  add_vars = list(
    right = function(data, node) {
      node$node$split$right
    }
  )
)
gg$data$right
```
But what if we want to add data to our node's `data` so that it is simultaneously
accessible through a single geom?
One way to do it, is to name the list element with the prefix `"nodedata_"`
and assign a `'function'` which returns a `'list'` for each node. It is important
that the lists be of
the same `length` as the lists created from the node's `data`. I.e. the new data has
to have the
same number of observations as the node's data since it needs to fit into one
`'data.frame'`. We are effectively adding columns to the node's `data`.
As we can see below, the plot
data's nodesize can be useful to make sure of this.
Once we call `geom_node_plot()` this data will
be readily available through `gglist` under its name (which we set for it as the
name of the list element) without the prefix - just like all the node's `data`.
```{r, eval = T}
gg <- ggparty(
  py,
  add_vars = list(
    # the "nodedata_" prefix marks the result as an addition to the node's data
    nodedata_x_dens = function(data, node) {
      # n = nodesize yields as many density points as the plot data's nodesize
      list(density(node$data$temperature, n = data$nodesize)$x)
    }
  )
)
gg$data$nodedata_x_dens
```
The obvious limitation of this method is, that the number of
observations has to be identical to the `nodesize`. In this case we achieved
this by setting `n` of `density()` to the `nodesize`.
If we want to plot custom data of different dimensions we can simply supply it
via the `data` argument of the `geoms` in `gglist`. Though in that case we won't be able to
access it simultaneously with the node's `data` in the same `geom`. To ensure
correct behaviour this `'data.frame'` has to contain a column named `id` specifying the
`id` of the node it belongs to.
# Node Plots
If we want to plot the `data` contained within the individual nodes of the tree,
we need to add `geom_node_plot()` to our `ggparty()` call. To understand why this is
necessary let's reiterate what `ggparty()` does and how it uses the `ggplot()`
function. Every `ggplot()` call needs a `'data.frame'`, so as we've seen above
`ggparty()` creates one from the `'party'` object. In this `'data.frame'` every row
corresponds to a node of the tree.
Each column of this node's `data` is stored as a `'list'` in its own
column. This way it is not usable by `ggplot()`, since
`ggplot()` can't handle lists inside its data. This is where `geom_node_plot()` comes
into play and
each instance of `geom_node_plot()` creates a completely separate `ggplot()` call after
transforming all the columns containing lists of data (created by `ggparty()`)
into a new `'data.frame'` for the new separate `ggplot()` call.
All the other columns of
ggparty's `'data.frame'` (like `kids`, `parent`, etc.) get
lost in this process, since usually we will not be interested in these when
plotting the node data and they
could potentially cause naming conflicts. In case we do want to use them, there
is a [fairly easy way](#Additional_Data) to do so.
So by default we can access anything that can be found in the data slot of the
party object, the fitted_nodes and additionally if the `'party'` object contains any,
the `fitted.values` and the `residuals` of the included model.
Now let's take a look at a constparty object created from the same data.
```{r, eval = T}
# root node splitting with sp_o into three kids (ids 2 to 4) without splits
n1 <- partynode(id = 1L, split = sp_o, kids = lapply(2L:4L, partynode))

t2 <- party(
  n1,
  data = WeatherPlay,
  # fitted node and response per observation; check.names = FALSE keeps the
  # parenthesised column names unchanged
  fitted = data.frame(
    "(fitted)" = fitted_node(n1, data = WeatherPlay),
    "(response)" = WeatherPlay$play,
    check.names = FALSE
  ),
  terms = terms(play ~ ., data = WeatherPlay)
)
# convert the party object to a constparty
t2 <- as.constparty(t2)
```
To visualize the distribution of the variable `play` we will use the `geom_node_plot()` function. It allows us to show the `data` of each node in its separate plot. For this to work, we have to specify the argument `gglist`. Basically we have to provide a `'list'` of all the `'gg'` components we would add to a `ggplot()` call on the `data` element of a node.
```{r, fig.width = 3, fig.asp = 0.8, eval = T}
# plain ggplot of the data of node 2: filled bar of play
ggplot(t2[2]$data) +
  geom_bar(aes(x = "", fill = play), position = position_fill()) +
  xlab("play")
```
So if we were to use the above code to create the desired plot for one node, we can instead pass a `'list'` of the two components to `gglist` and `geom_node_plot` will create a version of it for every specified node (per default the `terminal` nodes). Keep in mind, that since it's a `'list'` we need to use `","` instead of `"+"` to combine the components.
```{r, fig.asp=1, fig.width = 7, eval = T}
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # gglist takes a list of all ggplot components to draw for each node
  # (default: terminal nodes); components are separated by "," not "+"
  geom_node_plot(
    gglist = list(
      geom_bar(aes(x = "", fill = play), position = position_fill()),
      xlab("play")
    )
  )
```
## Axes and Legends
Setting `shared_axis_labels` to `TRUE` allows us to use the space more efficiently
and `legend_separator = TRUE` draws a line between the tree and the legend.
```{r, fig.asp=1, fig.width = 7, eval = T}
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(
    gglist = list(
      geom_bar(aes(x = "", fill = play), position = position_fill()),
      xlab("play")
    ),
    # one label per axis instead of one per plot
    shared_axis_labels = TRUE,
    # line separating the tree from the legend
    legend_separator = TRUE
  )
```
Setting `shared_legend` to `FALSE` draws an individual legend at each plot instead
of one common at the bottom of the plot. This might be necessary if we use
multiple different `geom_node_plots()` which lead to various legends. In case we want
to remove the legend altogether (i.e. `theme(legend.position = "none")`)
`shared_legend` has to be set to `FALSE`.
```{r, fig.asp=1, fig.width = 7, eval = T}
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(
    gglist = list(
      geom_bar(aes(x = "", fill = play), position = position_fill()),
      xlab("play")
    ),
    # every plot gets its own legend
    shared_legend = FALSE
  )
```
Thanks to the versatility of **ggplot2** we are also very flexible in creating these
node plots. For example the barplot can be easily changed into a pie chart.
The argument `size` of `geom_node_plot()` can be set to `"nodesize"` which changes the
size of
the node plot relative to the number of observations in the respective node.
```{r, fig.width = 7, eval = T}
ggparty(t2) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # pie charts (filled bar in polar coordinates), sized relative to nodesize
  geom_node_plot(
    gglist = list(
      geom_bar(aes(x = "", fill = play), position = position_fill()),
      coord_polar("y"),
      theme_void()
    ),
    size = "nodesize"
  )
```
## Predictions
If the party object contains a model with only one predictor we can use the
argument `predict` to choose to show a prediction line. Additional arguments for
the `geom_line()` drawing this line can be passed via `predict_gpar`.
So let's take a look at this `'lmtree'` containing linear models explaining `eval` with
`beauty`.
```{r, eval = T}
data("TeachingRatings", package = "AER")
# restrict to the rows where credits == "more"
tr <- subset(TeachingRatings, credits == "more")
# linear models eval ~ beauty, partitioned over the covariates after "|"
tr_tree <- lmtree(
  eval ~ beauty | minority + age + gender + division + native + tenure,
  data = tr,
  weights = students,
  caseweights = FALSE
)
```
```{r}
ggparty(tr_tree) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  geom_node_plot(
    gglist = list(
      geom_point(aes(x = beauty, y = eval, col = tenure, shape = minority),
                 alpha = 0.8),
      theme_bw(base_size = 10)
    ),
    shared_axis_labels = TRUE,
    legend_separator = TRUE,
    # variable the prediction line is based on
    predict = "beauty",
    # graphical parameters for the geom_line() drawing the predictions
    predict_gpar = list(col = "blue", size = 1.2)
  )
```
In case we want to generate predictions for a more complicated model, we need to
do this beforehand and pass the new data through the `data` argument inside
`geom_node_plot()`'s `gglist`.
First the tree of class `'party'` is created using the **partykit** infrastructure.
```{r, eval = T}
data("GBSG2", package = "TH.data")
# rescale time by 365 (presumably days to years - confirm against the data docs)
GBSG2$time <- GBSG2$time / 365
library("survival")
# Fitting function handed to mob() via its `fit` argument: a Weibull survreg()
# of y on x without an intercept term (the "0 +" in the formula).
# `weights` and `...` are forwarded to survreg(); `start` and `offset` are
# accepted in the signature but not used in the body.
wbreg <- function(y, x, start = NULL, weights = NULL, offset = NULL, ...) {
survreg(y ~ 0 + x, weights = weights, dist = "weibull", ...)
}
# logLik method for 'survreg' objects: returns the second element of
# object$loglik as a "logLik" object whose df attribute is the sum of object$df.
logLik.survreg <- function(object, ...) {
  loglik_value <- object$loglik[2]
  total_df <- sum(object$df)
  structure(loglik_value, df = total_df, class = "logLik")
}
# model-based tree: Surv(time, cens) modelled by horTh and pnodes with wbreg,
# partitioned over the variables after "|"; minsize is set to 80
gbsg2_tree <- mob(
  Surv(time, cens) ~ horTh + pnodes |
    age + tsize + tgrade + progrec + estrec + menostat,
  data = GBSG2,
  fit = wbreg,
  control = mob_control(minsize = 80)
)
```
So in this case we want to create a sequence over the range of the metric variable
`pnodes` and combine it once with the first level of the binary variable `horTh` and
once with the second. Using this data we then (in this case) need to generate predictions of
the type `"quantile"` with `p` set to `0.5`. The function `get_predictions()` can help us
with the second part since it applies a `newdata` function defined by us to each node and returns a suitable `'data.frame'`.
If we want to use it, we need to supply the `'party'` object, a function that creates
the new data from each node's `data` and optionally `predict_arg`, additional arguments to pass
to the `predict()` call.
```{r}
# Create the newdata for the predictions of one node: an evenly spaced grid
# over the range of the node's pnodes (as many points as observations),
# repeated once for horTh = "yes" and once for horTh = "no". The model matrix
# of these two columns is attached as the matrix column x.
generate_newdata <- function(data) {
  n_obs <- length(data$pnodes)
  pnodes_grid <- seq(from = min(data$pnodes),
                     to = max(data$pnodes),
                     length.out = n_obs)
  newdata <- data.frame(
    horTh = factor(rep(c("yes", "no"), each = n_obs)),
    pnodes = rep(pnodes_grid, 2)
  )
  newdata$x <- model.matrix(~ ., data = newdata)
  newdata
}
# get_predictions() is a convenience function that creates the data frame of
# predictions from the newdata generated for each node
pred_df <- get_predictions(
  gbsg2_tree,
  # IMPORTANT: set the same ids as in the geom_node_plot() call that is
  # later used for plotting
  ids = "terminal",
  newdata_fun = generate_newdata,
  # additional arguments for the predict() call
  predict_arg = list(type = "quantile", p = 0.5)
)
```
The `'data.frame'` created this way can then be passed to any `'gg'` component in
`geom_node_plot()`'s `gglist`. In this case we want to draw a line for both values of
`horTh` and separate them by color.
```{r, fig.asp = 0.8, fig.width=7, eval = T}
ggparty(gbsg2_tree, terminal_space = 0.8, horizontal = TRUE) +
  geom_edge() +
  geom_node_splitvar() +
  geom_edge_label() +
  geom_node_plot(
    gglist = list(
      # points taken from the node's data
      geom_point(aes(y = `Surv(time, cens).time`, x = pnodes, col = horTh),
                 alpha = 0.6),
      # prediction lines: pred_df is supplied as geom_line()'s data argument
      geom_line(data = pred_df,
                aes(x = pnodes, y = prediction, col = horTh),
                size = 1.2),
      theme_bw(),
      ylab("Survival Time")
    ),
    ids = "terminal", # not necessary since default
    shared_axis_labels = TRUE
  )
```
## Potential Pitfalls
### Combining `'gg'` Components in `gglist` with `"+"`
The object passed to `gglist` has to be a `'list'` and therefore we must not use `"+"` to
combine the components of a `geom_node_plot()` but instead `","`.
### Passing Components at the Wrong Place
As we now know, each `geom_node_plot()` is basically a completely separate plot with
its own arguments and specifications which are independent from the base plot of
the tree (i.e. the ggparty call with edges, labels, etc.). For that reason, if
for example, we want to remove the legend of a `geom_node_plot()` we must not pass it
at the base level (as a component of the tree) but inside the `gglist` of the
`geom_node_plot()`.
# Node Labels
`geom_node_label()` is a modified version of **ggplot2**'s `geom_label()` which
allows for multi-line labels. However the basic functionality of `geom_label()` is
still present. This means that if we are content with uniform aesthetics for the
whole label, we can simply use `geom_node_label()` as we would `geom_label()` with
the only difference, that `x` and `y` are already mapped per default to the nodes
coordinates.
If we want to have to specify even fewer mappings, we can use
`geom_node_splitvar()` and `geom_node_info()`. These are wrappers of `geom_node_label()`
with the respective defaults to plot the `splitvar` in the inner nodes or
the `info` in the terminal nodes.
## Multi-Line Labels
`geom_node_label()` allows us to create multiline labels and specify individual
graphical parameters for each line. To do this, we must not map
anything to
`label` in the `aes()` passed to `mapping`,
but instead pass a `'list'` of `aes()` to the argument `line_list`. The order of the `'list'`
is the same as the order in which the lines will be printed. Additionally we
have to pass a `'list'` to `line_gpar`. This list must be the same `length` as `line_list`
and contain separately named `'lists'` of graphical parameters. If we don't want to
change anything for a specific line, the respective `'list'` has to be an empty
`'list'`.
Mapping with the `mapping` argument of `geom_node_label()` still works and
affects all lines and the border together. The line specific graphical arguments
in `line_gpar` can be
used to overwrite these `mappings`.
Additionally to the usual aesthetic
parameters we would use for `ggplot`'s `geom_label()` we can pass `parse` and `alignment`
through `line_gpar`. Parse is equivalent to the behaviour of `geom_label()` and `alignment` enables us to position the text at the left or right label border.
All other mappings in `line_list` will be ignored.
It is not possible to map other line specific aesthetics to variables. It is only
possible to map the aesthetics of the complete label to variables and overwrite specific lines with fixed values in `line_gpar`. (In essence replicating the condition of mapping only one line to a variable, but we won't
be able to do this for multiple lines with different mappings).
This may seem very convoluted, but keep in mind, that we only have to go
through this process if we want to address the graphical parameters of specific
lines.
### Example
To create a tree consisting of inner nodes labeled by their split variable and
terminal nodes labeled by their coefficients we can use the code found below.
First we need to extract the coefficients with the help of the `add_vars` argument
of `ggparty()`. This step is necessary so that we can later access them by the names
given to them in the `'list'` supplied to `add_vars`.
Since we want to plot different elements in the inner and terminal nodes, we
need to add `geom_node_label()` twice. The first call is for the inner nodes.
With the `aes()` passed to `mapping` we map the `color` of the labels to the `splitvar` of the node.
For this tree we want to display the split
variable in the first line, then the p-value in scientific notation in the
second line, the third line is just a spacer therefore empty and the fourth and
last line is supposed to show the ID of the node. We specify the aesthetics we
want to override in `line_gpar`.
Using the third line as a spacer and setting `alignment` to "left" we can
position the `id` of the node at the bottom left corner of the labels.
Correspondingly we can plot the labels for the terminal nodes.
```{r, fig.width= 7, fig.asp= 0.6, eval = T}
ggparty(tr_tree,
        terminal_space = 0,
        # extract both coefficients of each node's model so that they can be
        # referenced by name in the labels below
        add_vars = list(intercept = "$node$info$coefficients[1]",
                        beta = "$node$info$coefficients[2]")) +
  geom_edge(size = 1.5) +
  geom_edge_label(colour = "grey", size = 4) +
  # first label inner nodes
  geom_node_label(# map color of complete label to splitvar
                  mapping = aes(col = splitvar),
                  # map content to label for each line
                  line_list = list(aes(label = splitvar),
                                   aes(label = paste("p =",
                                                     formatC(p.value,
                                                             format = "e",
                                                             digits = 2))),
                                   # empty third line acts as a spacer
                                   aes(label = ""),
                                   aes(label = id)
                  ),
                  # set graphical parameters for each line in same order
                  line_gpar = list(list(size = 12),
                                   list(size = 8),
                                   list(size = 6),
                                   list(size = 7,
                                        col = "black",
                                        fontface = "bold",
                                        alignment = "left")
                  ),
                  # only inner nodes
                  ids = "inner") +
  # next label terminal nodes
  geom_node_label(# map content to label for each line
                  line_list = list(
                    aes(label = paste("beta[0] == ", round(intercept, 2))),
                    aes(label = paste("beta[1] == ",round(beta, 2))),
                    aes(label = ""),
                    aes(label = id)
                  ),
                  # set graphical parameters for each line in same order;
                  # TRUE instead of the reassignable shorthand T
                  line_gpar = list(list(size = 12, parse = TRUE),
                                   list(size = 12, parse = TRUE),
                                   list(size = 6),
                                   list(size = 7,
                                        col = "black",
                                        fontface = "bold",
                                        alignment = "left")),
                  ids = "terminal",
                  # nudge labels towards bottom so that edge labels have enough space
                  # alternatively use shift argument of edge_label
                  nudge_y = -.05) +
  # don't show legend for splitvar mapping to color since self-explanatory
  theme(legend.position = "none") +
  # html_documents seem to cut off a bit too much at the edges so set limits manually
  coord_cartesian(xlim = c(0, 1), ylim = c(-0.1, 1.1))
```
# Layout
## Nodes
```{r, eval = T}
## Boston housing data
data("BostonHousing", package = "mlbench")
# recode chas as a factor labelled "no"/"yes" and rad as an ordered factor
BostonHousing <- transform(
  BostonHousing,
  chas = factor(chas, levels = 0:1, labels = c("no", "yes")),
  rad = factor(rad, ordered = TRUE)
)

## linear model tree
bh_tree <- lmtree(
  medv ~ log(lstat) + I(rm^2) |
    zn + indus + chas + nox + age + dis + rad + tax + crim + b + ptratio,
  data = BostonHousing,
  minsize = 40
)
```
Let's take a look at `ggparty()`'s layout system with the help of this `'lmtree'` based on `BostonHousing` data set from **mlbench**.
```{r, fig.width= 7, fig.asp=1, eval = T}
# terminal_space sets the value of y at which the terminal plots begin
bh_plot <- ggparty(bh_tree, terminal_space = 0.5) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # first row of plots: medv against log(lstat)
  geom_node_plot(
    gglist = list(
      geom_point(aes(y = medv, x = `log(lstat)`, col = chas), alpha = 0.6)
    ),
    # half the height; the plots shrink towards the top
    height = 0.5
  ) +
  # second row of plots: medv against I(rm^2)
  geom_node_plot(
    gglist = list(
      geom_point(aes(y = medv, x = `I(rm^2)`, col = chas), alpha = 0.6)
    ),
    height = 0.5,
    # shift by -0.25 in y to use the bottom half of the terminal space
    nudge_y = -0.25
  )
bh_plot
```
`ggparty()` positions all the nodes within the unit square. For vertical trees the root is always at `(0.5, 1)`, for horizontal ones it is at `(0, 0.5)`. The argument `terminal_space` specifies how much room should be left for terminal plots. The default value depends on the `depth` of the supplied tree. The terminal nodes are placed at this value and in case labels are drawn, they are drawn there. In case plots are to be drawn their top borders are aligned to this value, i.e. the terminal plots `just` is not `"center"` but `"top"`. Therefore reducing the `height` of a terminal node shrinks it towards the top.
So if we want to plot multiple plots per node we have to keep this in mind and can achieve this for example like this.
The first `geom_node_plot()` only takes the argument `height = 0.5` which halves its size and effectively makes it
occupy only the upper half of the area it would normally do. For the second `geom_node_plot()` we also specify
the `height` to be 0.5 but additionally we have to specify `nudge_y`. Since the terminal space is set to be 0.5,
we know that the first plot now spans from 0.5 to 0.25. So we want to move the line where to place the
second plot to 0.25, i.e. nudge it from 0.5 by -0.25.
Changing the theme from the default `theme_void` to one for which gridlines are drawn
allows us to see the above described layout structure.
```{r, fig.width=7, fig.asp = 1, eval = T}
# redraw bh_plot with theme_bw() instead of the default theme so that the
# gridlines are drawn
bh_plot + theme_bw()
```
We can use this information to manually set the positions of nodes. To do this
we must pass a `'data.frame'` containing the columns `id`, `x` and `y` to the `layout`
argument of `ggparty()`.
```{r, fig.width= 7, fig.asp=1, eval = T}
ggparty(
  bh_tree,
  terminal_space = 0.5,
  # one row per repositioned node: id selects the node, x and y give its
  # coordinates and have to lie between 0 and 1
  layout = data.frame(
    id = c(1, 2),
    x = c(0.7, 0.3),
    y = c(1, 0.9)
  )
) +
  geom_edge() +
  geom_edge_label() +
  geom_node_splitvar() +
  # upper row of terminal plots
  geom_node_plot(
    gglist = list(
      geom_point(
        aes(y = medv, x = `log(lstat)`, col = chas),
        alpha = 0.6
      )
    ),
    height = 0.5
  ) +
  # lower row of terminal plots, moved down by 0.25
  geom_node_plot(
    gglist = list(
      geom_point(
        aes(y = medv, x = `I(rm^2)`, col = chas),
        alpha = 0.6
      )
    ),
    height = 0.5,
    nudge_y = -0.25
  ) +
  # theme with gridlines
  theme_bw()
```
## Axes, Legends and Limits
As mentioned the nodes of the tree should always be positioned inside the unit square.
In case of a shared legend and no shared axis labels, it is plotted at
`(0.5, -0.05)` with `just = "top"`. In case shared axis labels are used, `just` changes
to `"bottom"` (i.e. the legend shifts approximately `0.05 units` downwards), and the
x axis label takes its position. Furthermore, the
shared y axis label will be plotted outside the unit square. This means that limits
based on the unit square will often not be sufficient to capture
all elements; `ggparty()` should be able to cope with these situations automatically.
In case you should need to adjust the x and y limits anyway, be advised to use
`coord_cartesian(xlim, ylim)` instead of `ylim` and `xlim` since the latter can easily
lead to unintended consequences by removing observations outside the plot limits.
# Autoplot Methods
The objects used in this document can also be plotted using the autoplot methods
provided by **ggparty**.
```{r, eval = T}
# draw py with its autoplot method, using default settings
autoplot(py)
```
```{r, eval = T}
# draw t2 with its autoplot method, using default settings
autoplot(t2)
```
```{r, fig.asp = 1, eval = T}
# two autoplots of bh_tree, one per regressor passed as plot_var
# NOTE(review): show_fit presumably toggles drawing of the fitted model in the
# node plots - confirm against the autoplot method's documentation
autoplot(bh_tree, plot_var = "log(lstat)", show_fit = FALSE)
autoplot(bh_tree, plot_var = "I(rm^2)", show_fit = TRUE)
```
```{r, eval = T}
# autoplot of gbsg2_tree with "pnodes" passed as plot_var
autoplot(gbsg2_tree, plot_var = "pnodes")
```
```{r, fig.asp = 1, eval = T}
# draw tr_tree with its autoplot method, using default settings
autoplot(tr_tree)
```
# Examples
Using the techniques covered in this document we should now be able to plot quite nice
trees of any `'party'` object without much effort. Let's take a look at a few
possibilities using the `tr_tree` we are already familiar with.
```{r, fig.width= 7, fig.asp= 1, eval = T}
# Translate p values into the conventional significance stars.
#
# p_value: numeric vector of p values (a single value works as well).
# Returns a character vector of the same length as p_value containing
# "***" for values below 0.001, "**" below 0.01, "*" below 0.05 and ""
# otherwise. Missing p values yield "" instead of raising an error, so the
# function can be applied to a whole column of p values at once.
asterisk_sign <- function(p_value) {
  stars <- rep("", length(p_value))
  # which() drops NA comparisons; stricter thresholds are applied last so that
  # they overwrite the weaker ones
  stars[which(p_value < 0.05)] <- "*"
  stars[which(p_value < 0.01)] <- "**"
  stars[which(p_value < 0.001)] <- "***"
  stars
}
ggparty(tr_tree, terminal_space = 0.5) +
  geom_edge(size = 1.5) +
  geom_edge_label(colour = "grey", size = 4) +
  # terminal nodes: residuals against fitted values of each terminal model
  geom_node_plot(
    gglist = list(
      geom_point(
        aes(
          x = fitted_values,
          y = residuals,
          col = tenure,
          shape = minority
        ),
        alpha = 0.8
      ),
      geom_hline(yintercept = 0),
      theme_bw(base_size = 10)
    ),
    # the y scale stays fixed for better comparability,
    # the x scale is free for efficient use of space
    scales = "free_x",
    ids = "terminal",
    shared_axis_labels = TRUE
  ) +
  # inner nodes: one label line each for node ID, split variable and p value
  geom_node_label(
    aes(col = splitvar),
    line_list = list(
      aes(label = paste("Node", id)),
      aes(label = splitvar),
      aes(label = asterisk_sign(p.value))
    ),
    # graphical parameters, one list per label line
    line_gpar = list(
      list(size = 8, col = "black", fontface = "bold"),
      list(size = 12),
      list(size = 8)
    ),
    ids = "inner"
  ) +
  # terminal nodes: label with node ID and node size
  geom_node_label(
    aes(label = paste0("Node ", id, ", N = ", nodesize)),
    fontface = "bold",
    ids = "terminal",
    size = 3,
    # the top (not the center) of a terminal node plot sits at the node's
    # coordinates, so a nudge_y of 0.01 is enough to place the label above it
    nudge_y = 0.01
  ) +
  theme(legend.position = "none")
```
This is the code for the example at the beginning of the document.
```{r, fig.asp = 1, eval = T}
# Create a data.frame with ids, age densities and breaks.
# It is supplied directly to geoms inside gglist, so we don't need to worry
# about the number of observations per id, and only the ids used by the
# respective geom_node_plot() have to be covered (2 and 5 in this case).
node_ids <- c(2, 5)
# age value above which a density point is marked "right" instead of "left"
# (50 for node 2, 40 for node 5)
age_cutoffs <- c(50, 40)
dens_list <- Map(
  function(node_id, cutoff) {
    # estimate the density once per node and reuse both of its components
    dens <- density(tr_tree[node_id]$data$age)
    data.frame(
      x_dens = dens$x,
      y_dens = dens$y,
      id = node_id,
      breaks = ifelse(dens$x > cutoff, "right", "left")
    )
  },
  node_ids,
  age_cutoffs
)
# bind the per-node pieces in a single step instead of growing the
# data.frame inside a loop
dens_df <- do.call(rbind, dens_list)
# custom layout so that each node plot has enough space
ggparty(
  tr_tree,
  terminal_space = 0.4,
  layout = data.frame(
    id = c(1, 2, 5, 7),
    x = c(0.35, 0.15, 0.7, 0.8),
    y = c(0.95, 0.6, 0.8, 0.55)
  )
) +
  # edge colour is mapped to birth_order (order from left to right)
  geom_edge(
    aes(col = factor(birth_order)),
    size = 1.2,
    alpha = 1,
    # leave out the root so that it doesn't count as its own colour
    ids = -1
  ) +
  # nodes 2 and 5: density plots for the age splits
  geom_node_plot(
    ids = c(2, 5),
    gglist = list(
      # density line taken from dens_df
      geom_line(
        data = dens_df,
        aes(
          x = x_dens,
          y = y_dens
        ),
        show.legend = FALSE,
        alpha = 0.8
      ),
      # ribbon below the density taken from dens_df, fill mapped to breaks
      geom_ribbon(
        data = dens_df,
        aes(
          x = x_dens,
          ymin = 0,
          ymax = y_dens,
          fill = breaks
        ),
        show.legend = FALSE,
        alpha = 0.8
      ),
      xlab("age"),
      theme_bw(),
      theme(axis.title.y = element_blank())
    ),
    size = 1.5,
    height = 0.5
  ) +
  # root: bar plot of gender
  geom_node_plot(
    ids = 1,
    gglist = list(
      geom_bar(
        aes(x = gender, fill = gender),
        show.legend = FALSE,
        alpha = .8
      ),
      theme_bw(),
      theme(axis.title.y = element_blank())
    ),
    size = 1.5,
    height = 0.5
  ) +
  # node 7: bar plot of division
  geom_node_plot(
    ids = 7,
    gglist = list(
      geom_bar(
        aes(x = division, fill = division),
        show.legend = FALSE,
        alpha = .8
      ),
      theme_bw(),
      theme(axis.title.y = element_blank())
    ),
    size = 1.5,
    height = 0.5
  ) +
  # terminal nodes: scatter plots with predictions
  geom_node_plot(
    gglist = list(
      geom_point(
        aes(
          x = beauty,
          y = eval,
          col = tenure,
          shape = minority
        ),
        alpha = 0.8
      ),
      theme_bw(base_size = 10),
      scale_color_discrete(h.start = 100)
    ),
    shared_axis_labels = TRUE,
    legend_separator = TRUE,
    predict = "beauty",
    predict_gpar = list(
      col = "blue",
      size = 1.1
    )
  ) +
  # drop all legends at the top level since the plot is self explanatory
  theme(legend.position = "none")
```