-
Notifications
You must be signed in to change notification settings - Fork 6.8k
MXNET-873 - Bring Clojure Package Inline with New DataDesc and Layout in Scala Package #12387
Changes from all commits
637e10e
5db1306
44fa15c
a1aa0f2
25959cc
6fd245d
a53f869
f8a125b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,11 +17,12 @@ | |
|
||
(ns org.apache.clojure-mxnet.io | ||
(:refer-clojure :exclude [next]) | ||
(:require [org.apache.clojure-mxnet.base :as base] | ||
(:require [clojure.spec.alpha :as s] | ||
[org.apache.clojure-mxnet.base :as base] | ||
[org.apache.clojure-mxnet.shape :as mx-shape] | ||
[org.apache.clojure-mxnet.util :as util] | ||
[org.apache.clojure-mxnet.dtype :as dtype] | ||
[clojure.spec.alpha :as s] | ||
[org.apache.clojure-mxnet.layout :as layout] | ||
[org.apache.clojure-mxnet.ndarray :as ndarray] | ||
[org.apache.clojure-mxnet.random :as random]) | ||
(:import (org.apache.mxnet IO DataDesc DataBatch NDArray) | ||
|
@@ -57,18 +58,48 @@ | |
|
||
(defn resize-iter [iter nbatch]) | ||
|
||
(defn provide-data [pack-iterator] | ||
(defn provide-data | ||
"Provides the description of the data iterator in the form of | ||
[{:name name :shape shape-vec}]" | ||
[pack-iterator] | ||
(->> pack-iterator | ||
(.provideData) | ||
(util/scala-map->map) | ||
(mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)})))) | ||
|
||
(defn provide-label [pack-iterator] | ||
(defn provide-label | ||
"Provides the description of the label iterator in the form of | ||
[{:name name :shape shape-vec}]" | ||
[pack-iterator] | ||
(->> pack-iterator | ||
(.provideLabel) | ||
(util/scala-map->map) | ||
(mapv (fn [[k v]] {:name k :shape (mx-shape/->vec v)})))) | ||
|
||
(defn data-desc->map [dd] | ||
{:name (.name dd) | ||
:shape (mx-shape/->vec (.shape dd)) | ||
:dtype (.dtype dd) | ||
:layout (.layout dd)}) | ||
|
||
(defn provide-data-desc | ||
"Provides the Data Desc of the data iterator in the form of | ||
[{:name name :shape shape-vec :dtype dtype :layout layout}]" | ||
[pack-iterator] | ||
(->> pack-iterator | ||
(.provideDataDesc) | ||
(util/scala-vector->vec) | ||
(mapv data-desc->map))) | ||
|
||
(defn provide-label-desc | ||
"Provides the Data Desc of the label iterator in the form of | ||
[{:name name :shape shape-vec :dtype dtype :layout layout}]" | ||
[pack-iterator] | ||
(->> pack-iterator | ||
(.provideLabelDesc) | ||
(util/scala-vector->vec) | ||
(mapv data-desc->map))) | ||
|
||
(defn reset [iterator] | ||
(.reset iterator)) | ||
|
||
|
@@ -194,7 +225,8 @@ | |
(defn ndarray-iter | ||
" * NDArrayIter object in mxnet. Taking NDArray to get dataiter. | ||
* | ||
* @param data vector of iter | ||
* @param data vector of iter - Can either by in the form for [ndarray..] or | ||
* {data-desc0 ndarray0 data-desc2 ndarray2 ...} | ||
* @opts map of: | ||
* :label Same as data, but is not fed to the model during testing. | ||
* :data-batch-size Batch Size (default 1) | ||
|
@@ -213,14 +245,23 @@ | |
last-batch-handle "pad" | ||
data-name "data" | ||
label-name "label"}}] | ||
(new NDArrayIter | ||
(util/vec->indexed-seq data) | ||
(if label (util/vec->indexed-seq label) (util/empty-indexed-seq)) | ||
(int data-batch-size) | ||
shuffle | ||
last-batch-handle | ||
data-name | ||
label-name)) | ||
(if (map? data) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. does it mean if data is map? what if it is not? I don't know clojure well, just want to make sure it is intended. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Yes - it checks to see if the data is a map, if so, it is in the form of having a DataDesc associated with it and will be dispatched to the correct Java function signature and with scala interop. If it is not a map, it will dispatch to the original Java function signature without DataDesc. The argument checking for the correct data structures can be improved by using core.spec in Clojure. It adds gradual type checking. It is in use in the module api, but it hasn't been added in yet in this namespace. I added a line item in the TODO page for the Clojure package to capture it for later improvement work. |
||
(new NDArrayIter | ||
(.toIndexedSeq (util/list-map data)) | ||
(if label | ||
(.toIndexedSeq (util/list-map label)) | ||
(util/empty-indexed-seq)) | ||
(int data-batch-size) | ||
shuffle | ||
last-batch-handle) | ||
(new NDArrayIter | ||
(util/vec->indexed-seq data) | ||
(if label (util/vec->indexed-seq label) (util/empty-indexed-seq)) | ||
(int data-batch-size) | ||
shuffle | ||
last-batch-handle | ||
data-name | ||
label-name))) | ||
([data] | ||
(ndarray-iter data {}))) | ||
|
||
|
@@ -230,24 +271,19 @@ | |
(s/def ::name string?) | ||
(s/def ::shape vector?) | ||
(s/def ::dtype #{dtype/UINT8 dtype/INT32 dtype/FLOAT16 dtype/FLOAT32 dtype/FLOAT64}) | ||
(s/def ::layout (s/or :custom string? :standard #{layout/UNDEFINED | ||
layout/NCHW | ||
layout/NTC | ||
layout/NT | ||
layout/N})) | ||
(s/def ::data-desc (s/keys :req-un [::name ::shape] :opt-un [::dtype ::layout])) | ||
|
||
;; NCHW is N:batch size C: channel H: height W: width | ||
;;; other layouts are | ||
;; NT, TNC, nad N | ||
;; the shape length must match the lengh of the layout string size | ||
(defn data-desc | ||
([{:keys [name shape dtype layout] :as opts | ||
:or {dtype base/MX_REAL_TYPE}}] | ||
:or {dtype base/MX_REAL_TYPE | ||
layout layout/UNDEFINED}}] | ||
(util/validate! ::data-desc opts "Invalid data description") | ||
(let [sc (count shape) | ||
layout (or layout (cond | ||
(= 1 sc) "N" | ||
(= 2 sc) "NT" | ||
(= 3 sc) "TNC" | ||
(= 4 sc) "NCHW" | ||
:else (apply str (repeat sc "?"))))] | ||
(new DataDesc name (mx-shape/->shape shape) dtype layout))) | ||
(new DataDesc name (mx-shape/->shape shape) dtype layout)) | ||
([name shape] | ||
(data-desc {:name name :shape shape}))) | ||
|
||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,35 @@ | ||
;; | ||
;; Licensed to the Apache Software Foundation (ASF) under one or more | ||
;; contributor license agreements. See the NOTICE file distributed with | ||
;; this work for additional information regarding copyright ownership. | ||
;; The ASF licenses this file to You under the Apache License, Version 2.0 | ||
;; (the "License"); you may not use this file except in compliance with | ||
;; the License. You may obtain a copy of the License at | ||
;; | ||
;; http://www.apache.org/licenses/LICENSE-2.0 | ||
;; | ||
;; Unless required by applicable law or agreed to in writing, software | ||
;; distributed under the License is distributed on an "AS IS" BASIS, | ||
;; WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
;; See the License for the specific language governing permissions and | ||
;; limitations under the License. | ||
;; | ||
|
||
(ns org.apache.clojure-mxnet.layout | ||
(:import (org.apache.mxnet Layout))) | ||
|
||
;; | ||
;; Layout definition of DataDesc | ||
;; N Batch size | ||
;; C channels | ||
;; H Height | ||
;; W Weight | ||
;; T sequence length | ||
;; __undefined__ default value of Layout | ||
;; | ||
|
||
(def UNDEFINED (Layout/UNDEFINED)) ;"__UNDEFINED__" | ||
(def NCHW (Layout/NCHW)) ;=> "NCHW" | ||
(def NTC (Layout/NTC)) ;=> "NTC" | ||
(def NT (Layout/NT)) ;=> "NT" | ||
(def N (Layout/N)) ;=> "N |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -144,7 +144,7 @@ | |
which must be known from the rest of the net." | ||
([start {:keys [step repeat dtype] | ||
:or {step (float 1) repeat (int 1) dtype base/MX_REAL_TYPE} | ||
:as opts}] | ||
:as opts}] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. space issue... There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. In this case the |
||
(Symbol/arange (float start) ($/option nil) step repeat true nil dtype)) | ||
([start] | ||
(arange-with-inference start {}))) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What does this
true
mean in here?There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's for the example - it controls whether the image url displays in a popup for you view
https://github.com/apache/incubator-mxnet/blob/6fd245d1ce28a7bc09b83aeb2a1233f085210037/contrib/clojure-package/examples/pre-trained-models/src/pre_trained_models/predict_image.clj#L42
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
nit: not a big deal since it's a helper fn in an example, but it'd make sense to use an options map or kw arg (
{:display true}
or:display true
). [outside this PR though, also]There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes agree