Skip to content

Commit

Permalink
test with dmatrix passing
Browse files Browse the repository at this point in the history
  • Loading branch information
behrica committed Oct 12, 2024
1 parent 6bd248e commit c7f74f1
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 83 deletions.
4 changes: 3 additions & 1 deletion deps.edn
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@
{org.clojure/clojure {:mvn/version "1.12.0"}
ml.dmlc/xgboost4j_2.12 {:mvn/version "2.1.1"}
;ml.dmlc/xgboost4j-spark_2.12 {:mvn/version "2.1.1"} ;; what for ??
org.scicloj/metamorph.ml {:mvn/version "0.9.0"}
org.scicloj/metamorph.ml {:git/url "https://github.com/scicloj/metamorph.ml"
:git/sha "a7cf0fa4545e80e3ab30df0f81879fd6f7d930d9"}
;{:mvn/version "0.9.0"}

com.github.haifengl/smile-core {:mvn/version "2.6.0"}

Expand Down
210 changes: 128 additions & 82 deletions test/scicloj/ml/text_test.clj
Original file line number Diff line number Diff line change
@@ -1,19 +1,24 @@
(ns scicloj.metamorph.text-test
(:require [tablecloth.api :as tc]
[tablecloth.column.api :as tcc]
[scicloj.metamorph.ml.csr :as csr]
(ns scicloj.ml.text-test
(:require [clojure.data.csv :as csv]
[clojure.java.io :as io]
[clojure.string :as str]
[clojure.test :refer [deftest is]]
[scicloj.metamorph.ml.loss :as loss]
[scicloj.metamorph.ml.text :as text]
[clojure.data.csv :as csv]
[scicloj.ml.xgboost :as xgboost]
[scicloj.metamorph.ml.loss :as loss])
(:import [ml.dmlc.xgboost4j.java XGBoost]
[scicloj.ml.xgboost.csr :as csr]
[tablecloth.api :as tc]
[tablecloth.column.api :as tcc])
(:import [java.util.zip GZIPInputStream]
[ml.dmlc.xgboost4j.java XGBoost]
[ml.dmlc.xgboost4j.java DMatrix DMatrix$SparseType]))




(def ds
(->
(text/->tidy-text "test/data/reviews.csv"
(text/->tidy-text (io/reader (GZIPInputStream. (io/input-stream "test/data/reviews.csv.gz")))
(fn [line]
(let [splitted (first
(csv/read-csv line))]
Expand All @@ -22,6 +27,7 @@

])
)
#(str/split % #" ")
:max-lines 10000
:skip-lines 1

Expand All @@ -30,13 +36,17 @@



(def ds
(tc/select-rows ds #(not (= "" (:word %)))))


(def rnd-indexes
(-> (range 1000) shuffle))
(-> (range 1000) (shuffle)))

(def rnd-indexes-train
(take 800 rnd-indexes))


(def rnd-indexes-test
(take-last 200 rnd-indexes))

Expand All @@ -52,6 +62,8 @@
text/->term-frequency
text/add-word-idx))



(def zero-baseddocs-map-train
(zipmap
(-> bow-train :document distinct)
Expand Down Expand Up @@ -157,36 +169,58 @@
(XGBoost/loadModel
(java.io.ByteArrayInputStream. (:model-data model))))

(def predition
(def predition-train
(->>
(.predict booster m-train)
(map #(int (first %)))))

(def predition-test
(->>
(.predict booster m-test)
(map #(int (first %)))))


(loss/classification-accuracy
(vec predition)
(vec labels-test))

(def train-accuracy
(loss/classification-accuracy
(vec predition-train)
(vec labels-train)))

(def test-accuracy
(loss/classification-accuracy
(vec predition-test)
(vec labels-test)))


(deftest reviews-accuracy-sparse-matrix-classification
(is (< 0.95 train-accuracy))
(is (< 0.54 test-accuracy)))

;;=> 0.973

;;------------------------------


(def result
(text/->tidy-text "test/data/small_text.csv"
(fn [line]
(let [splitted (first
(csv/read-csv line))]
(vector
(first splitted)
(dec (Integer/parseInt (second splitted))))))
:max-lines 10000
:skip-lines 1))
(comment

(def result
(text/->tidy-text "test/data/small_text.csv"
(fn [line]
(let [splitted (first
(csv/read-csv line))]
(vector
(first splitted)
(dec (Integer/parseInt (second splitted))))))
:max-lines 10000
:skip-lines 1))



(def ds (:ds result))
(def st (:st result))
(def ds (:ds result))

(def st (:st result))

ds
ds

;;=> _unnamed [12 4]:
;;
;; | :word | :word-index | :document | :label |
Expand All @@ -204,13 +238,14 @@ ds
;; | me | 3 | 1 | 1 |
;; | ? | 4 | 1 | 1 |
;;

(def bow
(-> ds
text/->term-frequency
text/add-word-idx))

(def bow
(-> ds
text/->term-frequency
text/add-word-idx))

bow
bow

;;=> _unnamed [11 5]:
;;
;; | :word | :document | :label | :tf | :word-idx |
Expand All @@ -227,35 +262,40 @@ bow
;; | me | 1 | 1 | 1 | 8 |
;; | ? | 1 | 1 | 1 | 9 |
;;

st
st


(def sparse-features
(-> bow
(tc/select-columns [:document :word-idx :tf])
(tc/rows)))
(def sparse-features
(-> bow
(tc/select-columns [:document :word-idx :tf])
(tc/rows)))

;;=> [[0 1 1] [0 2 1] [0 3 2] [0 4 1] [0 5 1] [0 6 1] [1 7 1] [1 5 1] [1 2 1] [1 8 1] [1 9 1]]



(def n-rows (inc (apply tcc/max (bow :document))))
n-rows
(def n-rows (inc (apply tcc/max (bow :document))))
n-rows
;;=> 2

(def n-col (inc (apply max (bow :word-idx))))
n-col

(def n-col (inc (apply max (bow :word-idx))))

n-col

;;=> 10

(def csr
(csr/->csr sparse-features))

(def csr
(csr/->csr sparse-features))

(def dense
(csr/->dense csr n-rows n-col))
(def dense
(csr/->dense csr n-rows n-col))
;; 0 1 2 3 4 5 6 7 8 9
;;=> ((0 1 1 2 1 1 1 0 0 0) ; I like fish fish and you the
;; (0 0 1 0 0 1 0 1 1 1)) ; like you do me ?

bow

bow

;;=> _unnamed [11 5]:
;;
;; | :word | :document | :label | :tf | :word-idx |
Expand All @@ -272,49 +312,55 @@ bow
;; | me | 1 | 1 | 1 | 8 |
;; | ? | 1 | 1 | 1 | 9 |
;;



(def labels
(->
bow
(tc/group-by :document)
(tc/aggregate #(-> % :label first))
(tc/column "summary")))
labels
(def labels
(->
bow
(tc/group-by :document)
(tc/aggregate #(-> % :label first))
(tc/column "summary")))
labels
;;=> #tech.v3.dataset.column<int64>[2]
;; summary
;; [0, 1]



(def m
(DMatrix.
(long-array (:row-pointers csr))
(int-array (:column-indices csr))
(float-array (:values csr))
DMatrix$SparseType/CSR
n-col))
(.setLabel m (float-array labels))


(def m
(DMatrix.
(long-array (:row-pointers csr))
(int-array (:column-indices csr))
(float-array (:values csr))
DMatrix$SparseType/CSR
n-col))
(.setLabel m (float-array labels))


(def model
(xgboost/train-from-dmatrix
m
["word"]
["label"]
{:num-class 2}
{}
"multi:softprob"))


(def model
(xgboost/train-from-dmatrix
m
["word"]
["label"]
{:num-class 2}
{}
"multi:softprob"))

(def booster
(XGBoost/loadModel
(java.io.ByteArrayInputStream. (:model-data model))))

(def predition
(.predict booster m))
predition
)

(def booster
(XGBoost/loadModel
(java.io.ByteArrayInputStream. (:model-data model))))
(def predition
(.predict booster m))

predition



Expand Down

0 comments on commit c7f74f1

Please sign in to comment.