-
Notifications
You must be signed in to change notification settings - Fork 321
/
Copy patheager.R
145 lines (116 loc) · 4.19 KB
/
eager.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
#' @export
as.array.python.builtin.EagerTensor <- function(x, ...) {
  # `...` is accepted for S3 generic compatibility and is not used.
  # A tensor whose Python pointer is null has no value to convert.
  if (py_is_null_xptr(x)) {
    return(NULL)
  }
  x$numpy()
}
# The same converter is registered for the TensorFlow-namespaced eager tensor
# class and for variables.
#' @export
as.array.tensorflow.python.framework.ops.EagerTensor <- as.array.python.builtin.EagerTensor
#' @export
as.array.tensorflow.python.ops.variables.Variable <- as.array.python.builtin.EagerTensor
#' @export
as.matrix.python.builtin.EagerTensor <- function(x, ...) {
  # `...` is accepted for S3 generic compatibility and is not used.
  if (py_is_null_xptr(x)) {
    return(NULL)
  }
  values <- x$numpy()
  # Values that already have exactly two dimensions are returned untouched;
  # anything else is coerced with as.matrix().
  is_two_dimensional <- length(dim(values)) == 2
  if (is_two_dimensional) values else as.matrix(values)
}
# The same converter is registered for the TensorFlow-namespaced eager tensor
# class and for variables.
#' @export
as.matrix.tensorflow.python.framework.ops.EagerTensor <- as.matrix.python.builtin.EagerTensor
#' @export
as.matrix.tensorflow.python.ops.variables.Variable <- as.matrix.python.builtin.EagerTensor
#' @export
as.integer.python.builtin.EagerTensor <- function(x, ...) {
  # `...` is accepted for S3 generic compatibility and is not used.
  if (py_is_null_xptr(x)) {
    return(NULL)
  }
  # Materialise the tensor as an R array first, then coerce to integer.
  materialised <- as.array(x)
  as.integer(materialised)
}
# The same converter is registered for the TensorFlow-namespaced eager tensor
# class and for variables.
#' @export
as.integer.tensorflow.python.framework.ops.EagerTensor <- as.integer.python.builtin.EagerTensor
#' @export
as.integer.tensorflow.python.ops.variables.Variable <- as.integer.python.builtin.EagerTensor
# NOTE(review): R documents that S3 methods for as.numeric() must be written
# for as.double(); these as.numeric.* methods are likely never reached by
# dispatch (the as.double.* methods are). Kept because they are exported —
# confirm before removing.
#' @export
as.numeric.python.builtin.EagerTensor <- function(x, ...) {
  # `...` is accepted for S3 generic compatibility and is not used.
  if (py_is_null_xptr(x)) {
    return(NULL)
  }
  # Materialise the tensor as an R array first, then coerce to numeric.
  materialised <- as.array(x)
  as.numeric(materialised)
}
#' @export
as.numeric.tensorflow.python.framework.ops.EagerTensor <- as.numeric.python.builtin.EagerTensor
#' @export
as.numeric.tensorflow.python.ops.variables.Variable <- as.numeric.python.builtin.EagerTensor
#' @export
as.double.python.builtin.EagerTensor <- function(x, ...) {
  # `...` is accepted for S3 generic compatibility and is not used.
  if (py_is_null_xptr(x)) {
    return(NULL)
  }
  # Materialise the tensor as an R array first, then coerce to double.
  materialised <- as.array(x)
  as.double(materialised)
}
# The same converter is registered for the TensorFlow-namespaced eager tensor
# class and for variables.
#' @export
as.double.tensorflow.python.framework.ops.EagerTensor <- as.double.python.builtin.EagerTensor
#' @export
as.double.tensorflow.python.ops.variables.Variable <- as.double.python.builtin.EagerTensor
#' @export
as.logical.python.builtin.EagerTensor <- function(x, ...) {
  # `...` is accepted for S3 generic compatibility and is not used.
  if (py_is_null_xptr(x)) {
    return(NULL)
  }
  # Materialise the tensor as an R array first, then coerce to logical.
  materialised <- as.array(x)
  as.logical(materialised)
}
# The same converter is registered for the TensorFlow-namespaced eager tensor
# class and for variables.
#' @export
as.logical.tensorflow.python.framework.ops.EagerTensor <- as.logical.python.builtin.EagerTensor
#' @export
as.logical.tensorflow.python.ops.variables.Variable <- as.logical.python.builtin.EagerTensor
#' Creates a callable TensorFlow graph from an R function.
#'
#' `tf_function` constructs a callable that executes a TensorFlow graph created
#' by tracing the TensorFlow operations in `f`. This allows the TensorFlow
#' runtime to apply optimizations and exploit parallelism in the computation
#' defined by `f`.
#'
#' A guide to getting started with
#' [`tf.function`](https://www.tensorflow.org/api_docs/python/tf/function) can
#' be found [here](https://www.tensorflow.org/guide/function).
#'
#' @param f the function to be compiled
#' @param input_signature A possibly nested sequence of `tf$TensorSpec` objects
#'   specifying the shapes and dtypes of the tensors that will be supplied to
#'   this function. If `NULL`, a separate function is instantiated for each
#'   inferred input signature. If `input_signature` is specified, every input to
#'   `f` must be a tensor.
#' @param autograph TRUE or FALSE. If TRUE (the default), you can use tensors in
#'   R control flow expressions `if`, `while`, `for` and `break` and they will
#'   be traced into the TensorFlow graph. A guide to getting started and
#'   additional details can be found
#'   [here](https://t-kalinowski.github.io/tfautograph/).
#' @param ... additional arguments passed on to `tf.function` (these vary by
#'   TensorFlow version). See
#'   [here](https://www.tensorflow.org/api_docs/python/tf/function#args_1) for
#'   details.
#'
#' @return The object returned by `tf$function()`: a callable that wraps `f`.
#'
#' @export
tf_function <- function(f,
                        input_signature = NULL,
                        autograph = TRUE,
                        ...) {
  if (!is.function(f)) {
    stop("`f` must be an R function", call. = FALSE)
  }
  if (!(isTRUE(autograph) || isFALSE(autograph))) {
    stop("`autograph` must be TRUE or FALSE", call. = FALSE)
  }
  if (autograph) {
    # tfautograph can't be registered in Imports yet due to a circular
    # dependency, so it is loaded on demand here.
    if (!requireNamespace("tfautograph", quietly = TRUE)) {
      stop(
        '"tfautograph" package required if autograph=TRUE. ',
        'Please run install.packages("tfautograph")',
        call. = FALSE
      )
    }
    f <- tfautograph::autograph(f)
  }
  # Python-side autograph is always turned off. Any control-flow conversion
  # for `f` was done in R by tfautograph above (presumably because Python
  # autograph cannot transform R code — confirm).
  # Arguments are passed by name so the call does not depend on the
  # positional order of tf.function's parameters.
  args <- list(
    py_func(f),
    input_signature = input_signature,
    autograph = FALSE,
    ...
  )
  do.call(tf$`function`, args)
}