From c7f74f162583408e0b04a1d074b679fe6b418b24 Mon Sep 17 00:00:00 2001 From: Carsten Behring Date: Sat, 12 Oct 2024 11:15:56 +0000 Subject: [PATCH] test with dmatrix passing --- deps.edn | 4 +- test/scicloj/ml/text_test.clj | 210 +++++++++++++++++++++------------- 2 files changed, 131 insertions(+), 83 deletions(-) diff --git a/deps.edn b/deps.edn index ddc3560..41d0d22 100644 --- a/deps.edn +++ b/deps.edn @@ -3,7 +3,9 @@ {org.clojure/clojure {:mvn/version "1.12.0"} ml.dmlc/xgboost4j_2.12 {:mvn/version "2.1.1"} ;ml.dmlc/xgboost4j-spark_2.12 {:mvn/version "2.1.1"} ;; what for ?? - org.scicloj/metamorph.ml {:mvn/version "0.9.0"} + org.scicloj/metamorph.ml {:git/url "https://github.com/scicloj/metamorph.ml" + :git/sha "a7cf0fa4545e80e3ab30df0f81879fd6f7d930d9"} + ;{:mvn/version "0.9.0"} com.github.haifengl/smile-core {:mvn/version "2.6.0"} diff --git a/test/scicloj/ml/text_test.clj b/test/scicloj/ml/text_test.clj index e20ba39..914ca77 100644 --- a/test/scicloj/ml/text_test.clj +++ b/test/scicloj/ml/text_test.clj @@ -1,19 +1,24 @@ -(ns scicloj.metamorph.text-test - (:require [tablecloth.api :as tc] - [tablecloth.column.api :as tcc] - [scicloj.metamorph.ml.csr :as csr] +(ns scicloj.ml.text-test + (:require [clojure.data.csv :as csv] + [clojure.java.io :as io] + [clojure.string :as str] + [clojure.test :refer [deftest is]] + [scicloj.metamorph.ml.loss :as loss] [scicloj.metamorph.ml.text :as text] - [clojure.data.csv :as csv] [scicloj.ml.xgboost :as xgboost] - [scicloj.metamorph.ml.loss :as loss]) - (:import [ml.dmlc.xgboost4j.java XGBoost] + [scicloj.ml.xgboost.csr :as csr] + [tablecloth.api :as tc] + [tablecloth.column.api :as tcc]) + (:import [java.util.zip GZIPInputStream] + [ml.dmlc.xgboost4j.java XGBoost] [ml.dmlc.xgboost4j.java DMatrix DMatrix$SparseType])) + (def ds (-> - (text/->tidy-text "test/data/reviews.csv" + (text/->tidy-text (io/reader (GZIPInputStream. (io/input-stream "test/data/reviews.csv.gz"))) (fn [line] (let [splitted (first (csv/read-csv line))] @@ -22,6 +27,7 @@ ]) ) + #(str/split % #" ") :max-lines 10000 :skip-lines 1 @@ -30,13 +36,17 @@ +(def ds + (tc/select-rows ds #(not (= "" (:word %))))) + (def rnd-indexes - (-> (range 1000) shuffle)) + (-> (range 1000) (shuffle))) (def rnd-indexes-train (take 800 rnd-indexes)) + (def rnd-indexes-test (take-last 200 rnd-indexes)) @@ -52,6 +62,8 @@ text/->term-frequency text/add-word-idx)) + + (def zero-baseddocs-map-train (zipmap (-> bow-train :document distinct) @@ -157,36 +169,58 @@ (XGBoost/loadModel (java.io.ByteArrayInputStream. (:model-data model)))) -(def predition +(def predition-train + (->> + (.predict booster m-train) + (map #(int (first %))))) + +(def predition-test (->> (.predict booster m-test) (map #(int (first %))))) -(loss/classification-accuracy - (vec predition) - (vec labels-test)) + +(def train-accuracy + (loss/classification-accuracy + (vec predition-train) + (vec labels-train))) + +(def test-accuracy + (loss/classification-accuracy + (vec predition-test) + (vec labels-test))) + + +(deftest reviews-accuracy-sparse-matrix-classification + (is (< 0.95 train-accuracy)) + (is (< 0.54 test-accuracy))) + ;;=> 0.973 ;;------------------------------ - -(def result - (text/->tidy-text "test/data/small_text.csv" - (fn [line] - (let [splitted (first - (csv/read-csv line))] - (vector - (first splitted) - (dec (Integer/parseInt (second splitted)))))) - :max-lines 10000 - :skip-lines 1)) +(comment + + (def result + (text/->tidy-text "test/data/small_text.csv" + (fn [line] + (let [splitted (first + (csv/read-csv line))] + (vector + (first splitted) + (dec (Integer/parseInt (second splitted)))))) + :max-lines 10000 + :skip-lines 1)) + -(def ds (:ds result)) -(def st (:st result)) + (def ds (:ds result)) + + (def st (:st result)) -ds + ds + ;;=> _unnamed [12 4]: ;; ;; | :word | :word-index | :document | :label | @@ -204,13 +238,14 @@ ds ;; | me | 3 | 1 | 1 | ;; | ? | 4 | 1 | 1 | ;; + + (def bow + (-> ds + text/->term-frequency + text/add-word-idx)) -(def bow - (-> ds - text/->term-frequency - text/add-word-idx)) - -bow + bow + ;;=> _unnamed [11 5]: ;; ;; | :word | :document | :label | :tf | :word-idx | @@ -227,35 +262,40 @@ bow ;; | me | 1 | 1 | 1 | 8 | ;; | ? | 1 | 1 | 1 | 9 | ;; - -st + + st -(def sparse-features - (-> bow - (tc/select-columns [:document :word-idx :tf]) - (tc/rows))) + (def sparse-features + (-> bow + (tc/select-columns [:document :word-idx :tf]) + (tc/rows))) + ;;=> [[0 1 1] [0 2 1] [0 3 2] [0 4 1] [0 5 1] [0 6 1] [1 7 1] [1 5 1] [1 2 1] [1 8 1] [1 9 1]] + - -(def n-rows (inc (apply tcc/max (bow :document)))) -n-rows + (def n-rows (inc (apply tcc/max (bow :document)))) + + n-rows ;;=> 2 - -(def n-col (inc (apply max (bow :word-idx)))) -n-col + + (def n-col (inc (apply max (bow :word-idx)))) + + n-col + ;;=> 10 + + (def csr + (csr/->csr sparse-features)) -(def csr - (csr/->csr sparse-features)) - -(def dense - (csr/->dense csr n-rows n-col)) + (def dense + (csr/->dense csr n-rows n-col)) ;; 0 1 2 3 4 5 6 7 8 9 ;;=> ((0 1 1 2 1 1 1 0 0 0) ; I like fish fish and you the ;; (0 0 1 0 0 1 0 1 1 1)) ; like you do me ? - -bow + + bow + ;;=> _unnamed [11 5]: ;; ;; | :word | :document | :label | :tf | :word-idx | @@ -272,49 +312,55 @@ bow ;; | me | 1 | 1 | 1 | 8 | ;; | ? | 1 | 1 | 1 | 9 | ;; + - -(def labels - (-> - bow - (tc/group-by :document) - (tc/aggregate #(-> % :label first)) - (tc/column "summary"))) -labels + (def labels + (-> + bow + (tc/group-by :document) + (tc/aggregate #(-> % :label first)) + (tc/column "summary"))) + labels + ;;=> #tech.v3.dataset.column[2] ;; summary ;; [0, 1] + + (def m + (DMatrix. + (long-array (:row-pointers csr)) + (int-array (:column-indices csr)) + (float-array (:values csr)) + DMatrix$SparseType/CSR + n-col)) + (.setLabel m (float-array labels)) + -(def m - (DMatrix. - (long-array (:row-pointers csr)) - (int-array (:column-indices csr)) - (float-array (:values csr)) - DMatrix$SparseType/CSR - n-col)) -(.setLabel m (float-array labels)) + (def model + (xgboost/train-from-dmatrix + m + ["word"] + ["label"] + {:num-class 2} + {} + "multi:softprob")) + -(def model - (xgboost/train-from-dmatrix - m - ["word"] - ["label"] - {:num-class 2} - {} - "multi:softprob")) + (def booster + (XGBoost/loadModel + (java.io.ByteArrayInputStream. (:model-data model)))) + + (def predition + (.predict booster m)) + predition + ) -(def booster - (XGBoost/loadModel - (java.io.ByteArrayInputStream. (:model-data model)))) -(def predition - (.predict booster m)) -predition