Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SYSTEMDS-3670] Added early-stopping mechanism to tSNE #1990

Merged
merged 10 commits into from
Feb 9, 2024
66 changes: 50 additions & 16 deletions scripts/builtin/tSNE.dml
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,12 @@
# lr Learning rate
# momentum Momentum Parameter
# max_iter Number of iterations
# tol Tolerance for early stopping in gradient descent
# seed The seed used for initial values.
# If set to -1 random seeds are selected.
# is_verbose Print debug information
# print_iter Intervals of printing out the L1 norm values. Parameter not relevant if
# is_verbose = FALSE.
# -------------------------------------------------------------------------------------------
#
# OUTPUT:
Expand All @@ -42,7 +45,8 @@
# -------------------------------------------------------------------------------------------

m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity = 30,
Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Integer seed = -1, Boolean is_verbose = FALSE)
Double lr = 300., Double momentum = 0.9, Integer max_iter = 1000, Double tol = 1e-5,
Integer seed = -1, Boolean is_verbose = FALSE, Integer print_iter = 10)
return(Matrix[Double] Y)
{
d = reduced_dims
Expand All @@ -63,15 +67,55 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity
if(is_verbose)
print("starting loop....")

for (itr in 1:max_iter) {
D = distance_matrix(Y)
itr = 1

# Start first iteration out of loop as benchmark for early stopping
D = dist(Y)
Z = 1/(D + 1)
Z = Z * ZERODIAG
Q = Z/sum(Z)
W = (P - Q)*Z
sumW = rowSums(W)
g = Y * sumW - W %*% Y
dY = momentum*dY - lr*g

norm = sum(dY^2)
norm_initial = norm
norm_target = norm_initial * tol

if(is_verbose){
print("L1 Norm initial : " + norm_initial)
print("L1 Norm target : " + norm_target)
}

Y = Y + dY
Y = Y - colMeans(Y)

if (itr%%100 == 0) {
C[itr/100,] = sum(P * log(pmax(P, 1e-12) / pmax(Q, 1e-12)))
}
if (itr == 100) {
P = P/4
}
itr = itr + 1
# End of first iteration

while (itr <= max_iter & norm > norm_target) {
D = dist(Y)
Z = 1/(D + 1)
Z = Z * ZERODIAG
Q = Z/sum(Z)
W = (P - Q)*Z
sumW = rowSums(W)
g = Y * sumW - W %*% Y
dY = momentum*dY - lr*g

norm = sum(dY^2)
if(is_verbose & itr %% print_iter == 0){
print("Iteration: " + itr)
print("L1 Norm: " + norm)
}

Y = Y + dY
Y = Y - colMeans(Y)

Expand All @@ -81,20 +125,10 @@ m_tSNE = function(Matrix[Double] X, Integer reduced_dims = 2, Integer perplexity
if (itr == 100) {
P = P/4
}
itr = itr + 1
}
}

distance_matrix = function(matrix[double] X)
return (matrix[double] out)
{
# TODO consolidate with dist() builtin, but with
# better way of obtaining the diag from
n = nrow(X)
s = rowSums(X * X)
out = - 2*X %*% t(X) + s + t(s)
}


x2p = function(matrix[double] X, double perplexity, Boolean is_verbose = FALSE)
return(matrix[double] P)
{
Expand All @@ -105,7 +139,7 @@ return(matrix[double] P)
n = nrow(X)
if(is_verbose)
print(n)
D = distance_matrix(X)
D = dist(X)

P = matrix(0, rows=n, cols=n)
beta = matrix(1, rows=n, cols=1)
Expand All @@ -119,7 +153,7 @@ return(matrix[double] P)
while (mean(abs(Hdiff)) > tol & itr < 50) {
P = exp(-D * beta)
P = P * ZERODIAG
sum_Pi = rowSums(P)
sum_Pi = rowSums(P) + 1e-12
W = rowSums(P * D)
Ws = W/sum_Pi
H = log(sum_Pi) + beta * Ws
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import org.apache.sysds.test.AutomatedTestBase;
import org.apache.sysds.test.TestConfiguration;
import org.junit.Test;
import static org.junit.Assert.assertTrue;

import java.io.IOException;

Expand All @@ -41,12 +42,12 @@ public void setUp() {
@Test
public void testTSNECP() throws IOException {
runTSNETest(2, 30, 300.,
0.9, 1000, 42, "FALSE", ExecType.CP);
0.9, 1000, 1e-5d, 42, "FALSE", 10, ExecType.CP);
}

@SuppressWarnings("unused")
private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr,
Double momentum, Integer max_iter, Integer seed, String is_verbose, ExecType instType)
private void runTSNETest(int reduced_dims, int perplexity, double lr,
double momentum, int max_iter, double tol, int seed, String is_verbose, Integer print_iter, ExecType instType)
throws IOException
{
ExecMode platformOld = setExecMode(instType);
Expand All @@ -64,8 +65,11 @@ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr,
"lr=" + lr,
"momentum=" + momentum,
"max_iter=" + max_iter,
"tol=" + tol,
"seed=" + seed,
"is_verbose=" + is_verbose};
"is_verbose=" + is_verbose,
"print_iter=" + print_iter
};

// The Input values are calculated using the following R script:
// TODO create via dml operations, avoid inlining data
Expand Down Expand Up @@ -403,4 +407,136 @@ private void runTSNETest(Integer reduced_dims, Integer perplexity, Double lr,
rtplatform = platformOld;
}
}


@Test
public void testTSNEEarlyStopping() throws IOException {
// Test setup guarantees early stopping.
runTSNEEarlyStoppingTest(2, 30, 300., 0.9, 1000, 1e-1, 1, "TRUE", 10, ExecType.CP);
}

@SuppressWarnings("unused")
private void runTSNEEarlyStoppingTest(
Integer reduced_dims,
Integer perplexity,
Double lr,
Double momentum,
Integer max_iter,
Double tol,
Integer seed,
String is_verbose,
Integer print_iter,
ExecType instType) throws IOException {

ExecMode platformOld = setExecMode(instType);
try
{
loadTestConfiguration(getTestConfiguration(TEST_NAME));

String HOME = SCRIPT_DIR + TEST_DIR;
fullDMLScriptName = HOME + TEST_NAME + ".dml";
programArgs = new String[]{
"-nvargs", "X=" + input("X"), "Y=" + output("Y"),
"reduced_dims=" + reduced_dims,
"perplexity=" + perplexity,
"lr=" + lr,
"momentum=" + momentum,
"max_iter=" + max_iter,
"tol=" + tol,
"seed=" + seed,
"is_verbose=" + is_verbose,
"print_iter=" + print_iter
};

// The Input values are calculated using the following dml script:
// X = rand(rows=50, cols=2, min=0, max=5, seed=1)

// Input
double[][] X = {
{-0.45700987356506406, 2.834752454661148},
{-1.3945444464226533, -3.8794723634582597},
{-1.338576451510809, 1.160918547857504},
{2.921891699889728, -1.32749074959577},
{1.3001464754324763, -1.1353208514333533},
{0.2866401390950628, 3.359214248871961},
{2.6740553056629217, -1.2030274345674852},
{1.7240446900374895, 3.4430052477647557},
{-0.3435254305219493, 4.205393963204703},
{-2.873899183923896, 1.098272406118296},
{4.890217056606042, 1.5814575251762104},
{-4.920042511612875, 4.579455675519821},
{1.439881754507784, -4.090781835042895},
{2.32372435941579, 4.823050596338641},
{-0.9864739586714544, -1.6990853495458147},
{4.605792626050157, 2.411639339263437},
{4.979120527950069, 1.7181757158820465},
{-4.423608438974177, 0.44712526968937283},
{3.4109472479162317, -3.497269670333382},
{-1.9938801849366037, -1.1880069697833906},
{3.223381639747396, 3.7784510177449793},
{2.10470587687118, 0.5415570090498525},
{2.084254693325721, 1.4369473809787037},
{-0.9957311983302795, 1.586795215124286},
{-3.7527381124013894, 4.3818220996816475},
{3.5748622228245193, 1.116518048277384},
{-2.297351475873446, -2.0179124546489047},
{-0.3438938003649259, 0.689249021371154},
{-0.8823286368673617, 1.2731356499886672},
{2.517220722615252, -2.8806532181877254},
{3.923092638022041, 4.34404320783608},
{-2.1012040153953, -4.33147229525127},
{3.5992422607685715, 2.5628828792092904},
{4.3431460760781775, -2.6869010463029754},
{-3.27506631006849, -1.1828954200032116},
{-4.3138906717810475, -3.7311556655569875},
{4.674799759142193, 3.783941497422669},
{3.561677127461424, 1.699651989293141},
{-3.0146338910401838, 3.3961817590254952},
{-4.438156472502506, 0.5926080631113129},
{-4.6425401564313615, 2.131545102584216},
{3.2975878235392244, -2.8485717910480988},
{-0.9776972765619627, 0.5292861827847535},
{-3.9770843662935915, -2.258269867772177},
{-4.22908475002643, -4.574457493889454},
{-0.28759876443714827, -0.5841999820607002},
{2.33121643992511, 1.7993339510854582},
{-1.476311475439723, 4.3511414590258894},
{4.974472387105775, -4.165990440844669},
{-4.570078514420281, 2.156235882831523}
};

writeInputMatrixWithMTD("X", X, true);

// Capture console output
setOutputBuffering(true);
String out = runTest(true, false, null, -1).toString();

// Parse and check L1 norm values
String[] lines = out.split(System.lineSeparator());
double prevL1Norm = Double.POSITIVE_INFINITY;
boolean decreasing = true;
int notDecreasingCount = 0; // Counter to track consecutive non-decreasing values
for (String line : lines) {
if (line.startsWith("L1 Norm:")) {
double l1Norm = Double.parseDouble(line.substring(9).trim());
if (l1Norm >= prevL1Norm) {
notDecreasingCount++;
if (notDecreasingCount >= 3) {
decreasing = false;
break; // Exit the loop once we've seen 3 consecutive non-decreasing values
}
} else {
notDecreasingCount = 0; // Reset the counter if the current value is decreasing
}
prevL1Norm = l1Norm;
}
}

assertTrue("L1 norm should decrease each time it is printed out", decreasing);
}
finally {
rtplatform = platformOld;
}

}
}
2 changes: 1 addition & 1 deletion src/test/scripts/functions/builtin/tSNE.dml
Original file line number Diff line number Diff line change
Expand Up @@ -20,5 +20,5 @@
#-------------------------------------------------------------

X = read($X);
Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $seed, $is_verbose)
Y = tSNE(X, $reduced_dims, $perplexity, $lr, $momentum, $max_iter, $tol, $seed, $is_verbose, $print_iter)
write(Y, $Y)
Loading