[SYSTEMDS-3829] BERT layer backward pass
Closes #2213
MaximilianSchreff authored and Baunsgaard committed Feb 5, 2025
1 parent 54c8696 commit 22642a1
Showing 77 changed files with 892 additions and 6 deletions.
182 changes: 182 additions & 0 deletions scripts/nn/layers/bert_layer.dml
@@ -37,6 +37,17 @@ linear_tensor_forward = function(matrix[double] X, matrix[double] W, matrix[doub
out = matrix(out, rows=A, cols=B*C_new)
}

linear_tensor_backward = function(matrix[double] dout, matrix[double] X, matrix[double] W, matrix[double] b, int B,
int C_out, int C_in)
return (matrix[double] dX, matrix[double] dW, matrix[double] db) {
/*
* Helper function for computing the backward pass of a linear layer with tensor input, of shape (A, B*C)
*/
A = nrow(X)
[dX, dW, db] = affine::backward(matrix(dout, rows=A*B, cols=C_out), matrix(X, rows=A*B, cols=C_in), W, b)
dX = matrix(dX, rows=A, cols=B*C_in)
}
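
For orientation, here is a NumPy sketch of what linear_tensor_backward does (an illustration only, not the DML above; affine_backward is a hypothetical stand-in for affine::backward): the (A, B*C_in) input and the (A, B*C_out) upstream gradient are reshaped so the B token blocks become extra batch rows, pushed through an ordinary dense-layer backward, and the input gradient is reshaped back.

import numpy as np

def affine_backward(dout, X, W, b):
    # Dense layer out = X @ W + b, so:
    dX = dout @ W.T                        # (N, C_in)
    dW = X.T @ dout                        # (C_in, C_out)
    db = dout.sum(axis=0, keepdims=True)   # (1, C_out)
    return dX, dW, db

def linear_tensor_backward(dout, X, W, b, B, C_out, C_in):
    # X: (A, B*C_in), dout: (A, B*C_out); fold the B token axis into the batch axis
    A = X.shape[0]
    dX, dW, db = affine_backward(dout.reshape(A * B, C_out), X.reshape(A * B, C_in), W, b)
    return dX.reshape(A, B * C_in), dW, db

The sketch assumes a row-major reshape, matching NumPy's default; with that layout each token's C features stay contiguous, so the same 2D affine backward applies unchanged.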

layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C)
return (matrix[double] out, matrix[double] cache_mean, matrix[double] cache_var, matrix[double] cache_norm) {
/*
@@ -51,6 +62,27 @@ layer_norm_forward = function(matrix[double] X, matrix[double] gamma, matrix[dou
out = matrix(t(batch_norm_out), rows=A, cols=B*C)
}

layer_norm_backward = function(matrix[double] dout, matrix[double] cache_mean, matrix[double] cache_var,
matrix[double] cache_norm, matrix[double] X, matrix[double] gamma, matrix[double] beta, double epsilon, int B, int C)
return (matrix[double] dX, matrix[double] dgamma, matrix[double] dbeta) {
/*
* Helper function for computing the backward pass of layer norm via 1D batch norm with tensor input, of shape (A, B*C)
*/
A = nrow(X)
batch_norm_input = t(matrix(X, rows=A*B, cols=C))
batch_norm_doutput = t(matrix(dout, rows=A*B, cols=C))
# EMA matrices, updated EMA matrices and out matrix are unused and thus empty matrices will be provided
empty_mat = matrix(0, rows=1, cols=1)
[batch_norm_dX, unused1, unused2] = batch_norm::backward(
batch_norm_doutput,
empty_mat, empty_mat, empty_mat,
cache_mean, cache_var, cache_norm,
batch_norm_input, t(gamma), t(beta), "train", empty_mat, empty_mat, 0.0, epsilon)
dX = matrix(t(batch_norm_dX), rows=A, cols=B*C)
dgamma = t(rowSums(batch_norm_doutput * cache_norm))
dbeta = t(rowSums(batch_norm_doutput))
}
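
As a cross-check on the gradient formulas, the following NumPy sketch computes the same per-token layer-norm backward directly (illustrative only: it recomputes the per-token mean and variance instead of reusing the cache_* matrices, and stands in for the batch_norm::backward call rather than reproducing it).

import numpy as np

def layer_norm_backward(dout, X, gamma, eps, B, C):
    # dout, X: (A, B*C); normalization is per token over its C features; gamma: (1, C)
    A = X.shape[0]
    x = X.reshape(A * B, C)
    dy = dout.reshape(A * B, C)
    mean = x.mean(axis=1, keepdims=True)
    var = x.var(axis=1, keepdims=True)
    x_norm = (x - mean) / np.sqrt(var + eps)

    # Parameter gradients: reduce over all tokens (the two rowSums in the DML above)
    dgamma = (dy * x_norm).sum(axis=0, keepdims=True)   # (1, C)
    dbeta = dy.sum(axis=0, keepdims=True)               # (1, C)

    # Standard layer-norm input gradient
    dxhat = dy * gamma
    dx = (dxhat - dxhat.mean(axis=1, keepdims=True)
          - x_norm * (dxhat * x_norm).mean(axis=1, keepdims=True)) / np.sqrt(var + eps)
    return dx.reshape(A, B * C), dgamma, dbeta

The dgamma/dbeta reductions correspond to the rowSums above once the transposed (C x A*B) layout used for the batch_norm call is accounted for.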

forward = function(matrix[double] states,
int H, int T, int d, int I,
matrix[double] W_Q, matrix[double] b_Q,
@@ -184,3 +216,153 @@ forward = function(matrix[double] states,
[out_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2] = layer_norm_forward(
out_states, gamma_ln2, beta_ln2, epsilon_ln, T, D)
}

backward = function(matrix[double] dout_states,
matrix[double] dropout_mask_attention,
matrix[double] dropout_mask_output_1,
matrix[double] dropout_mask_output_2,
matrix[double] cache_mean_ln1, matrix[double] cache_var_ln1, matrix[double] cache_norm_ln1,
matrix[double] cache_mean_ln2, matrix[double] cache_var_ln2, matrix[double] cache_norm_ln2,
list[unknown] outputs,
matrix[double] states,
int H, int T, int d, int I,
matrix[double] W_Q, matrix[double] b_Q,
matrix[double] W_K, matrix[double] b_K,
matrix[double] W_V, matrix[double] b_V,
matrix[double] W_context, matrix[double] b_context,
matrix[double] W_intermediate, matrix[double] b_intermediate,
matrix[double] W_out, matrix[double] b_out,
double dropout_p_attention,
double dropout_p_output,
double epsilon_ln,
matrix[double] gamma_ln1, matrix[double] beta_ln1,
matrix[double] gamma_ln2, matrix[double] beta_ln2,
string activation)
return (matrix[double] din_states,
matrix[double] dW_Q, matrix[double] db_Q,
matrix[double] dW_K, matrix[double] db_K,
matrix[double] dW_V, matrix[double] db_V,
matrix[double] dW_context, matrix[double] db_context,
matrix[double] dW_intermediate, matrix[double] db_intermediate,
matrix[double] dW_out, matrix[double] db_out,
matrix[double] dgamma_ln1, matrix[double] dbeta_ln1,
matrix[double] dgamma_ln2, matrix[double] dbeta_ln2) {
/*
* Computes the backward pass for a layer of the BERT transformer architecture.
*
* Inputs (B: Batch size, T: Sequence length, D: Embedding length, H: Heads):
* - dout_states: Gradients w.r.t. output states, of shape (B, T*D)
* - dropout_mask_attention: Dropout mask used on attention, of shape (B, H*T*T)
* - dropout_mask_output_1: Dropout mask used on attention output, of shape (B, T*D)
* - dropout_mask_output_2: Dropout mask used on the final feed-forward output, of shape (B, T*D)
* - cache_mean_ln1: Cached mean from layer norm 1, of shape (1, B*T)
* - cache_var_ln1: Cached variance from layer norm 1, of shape (1, B*T)
* - cache_norm_ln1: Cached normalized values from layer norm 1, of shape (1, B*T)
* - cache_mean_ln2: Cached mean from layer norm 2, of shape (1, B*T)
* - cache_var_ln2: Cached variance from layer norm 2, of shape (1, B*T)
* - cache_norm_ln2: Cached normalized values from layer norm 2, of shape (1, B*T)
* - outputs: list of relevant outputs from forward pass
* with the following order/content:
* -> 1: Output of linear query layer, of shape (B, T*D).
* -> 2: Output of linear key layer, of shape (B, T*D).
* -> 3: Output of linear value layer, of shape (B, T*D).
* -> 4: Output context of attention layer, of shape (B, T*D).
* -> 5: Output attention of attention layer, of shape (B, T*D).
* -> 6: Output of residual pass 1, of shape (B, T*D).
* -> 7: Output of layer norm 1, of shape (B, T*D).
* -> 8: Output of intermediate linear layer, of shape (B, T*I).
* -> 9: Output of activation layer, of shape (B, T*I).
* -> 10: Output of residual pass 2, of shape (B, T*D).
* - states: Hidden states, of shape (B, T*D).
* - H: Head count.
* - T: Sequence length.
* - d: Embedding length of single token per head with d*H = D.
* - I: Intermediate embedding length.
* - W_Q: Weights for linear query layer, of shape (D, D).
* - b_Q: Biases for linear query layer, of shape (1, D).
* - W_K: Weights for linear key layer, of shape (D, D).
* - b_K: Biases for linear key layer, of shape (1, D).
* - W_V: Weights for linear value layer, of shape (D, D).
* - b_V: Biases for linear value layer, of shape (1, D).
* - W_context: Weights for linear output layer on context, of shape (D, D).
* - b_context: Biases for linear output layer on context, of shape (1, D).
* - W_intermediate: Weights for intermediate linear layer, of shape (D, I).
* - b_intermediate: Biases for intermediate linear layer, of shape (1, I).
* - W_out: Weights for last linear output layer, of shape (I, D).
* - b_out: Biases for last linear output layer, of shape (1, D).
* - dropout_p_attention: Probability for dropout on attention.
* - dropout_p_output: Probability for dropout on output.
* - epsilon_ln: Epsilon value for layer norm.
* - gamma_ln1: Gamma params for layer norm 1, of shape (1, D).
* - beta_ln1: Beta params for layer norm 1, of shape (1, D).
* - gamma_ln2: Gamma params for layer norm 2, of shape (1, D).
* - beta_ln2: Beta params for layer norm 2, of shape (1, D).
* - activation: String specifying type of activation to use.
* Can be tanh or gelu.
*
* Outputs:
* - din_states: Gradients w.r.t. hidden input states, of shape (B, T*D).
* - dW_Q: Gradients w.r.t. weights for linear query layer, of shape (D, D).
* - db_Q: Gradients w.r.t. biases for linear query layer, of shape (1, D).
* - dW_K: Gradients w.r.t. weights for linear key layer, of shape (D, D).
* - db_K: Gradients w.r.t. biases for linear key layer, of shape (1, D).
* - dW_V: Gradients w.r.t. weights for linear value layer, of shape (D, D).
* - db_V: Gradients w.r.t. biases for linear value layer, of shape (1, D).
* - dW_context: Gradients w.r.t. weights for linear output layer on context, of shape (D, D).
* - db_context: Gradients w.r.t. biases for linear output layer on context, of shape (1, D).
* - dW_intermediate: Gradients w.r.t. weights for intermediate linear layer, of shape (D, I).
* - db_intermediate: Gradients w.r.t. biases for intermediate linear layer, of shape (1, I).
* - dW_out: Gradients w.r.t. weights for last linear output layer, of shape (I, D).
* - db_out: Gradients w.r.t. biases for last linear output layer, of shape (1, D).
* - dgamma_ln1: Gradients w.r.t. gamma params for layer norm 1, of shape (1, D).
* - dbeta_ln1: Gradients w.r.t. beta params for layer norm 1, of shape (1, D).
* - dgamma_ln2: Gradients w.r.t. gamma params for layer norm 2, of shape (1, D).
* - dbeta_ln2: Gradients w.r.t. beta params for layer norm 2, of shape (1, D).
*/
# Embedding dim
D = d * H

# Layer norm 2 for each token
[dout_states, dgamma_ln2, dbeta_ln2] = layer_norm_backward(
dout_states, cache_mean_ln2, cache_var_ln2, cache_norm_ln2, as.matrix(outputs[10]), gamma_ln2, beta_ln2, epsilon_ln, T, D)
# Save dout_states for residual pass
dout_states_identity_2 = dout_states
# Dropout on output 2
if (dropout_p_output > 0.0) {
dout_states = dropout::backward(dout_states, matrix(0, 1, 1), dropout_p_output, dropout_mask_output_2)
}
# Final linear output layer
[dout_states, dW_out, db_out] = linear_tensor_backward(dout_states, as.matrix(outputs[9]), W_out, b_out, T, D, I)

# Activation
if (activation == "gelu") {
dout_states = gelu::backward(dout_states, as.matrix(outputs[8]))
} else if (activation == "tanh") {
dout_states = tanh::backward(dout_states, as.matrix(outputs[8]))
}
# Linear layer of intermediate part
[dout_states, dW_intermediate, db_intermediate] = linear_tensor_backward(dout_states, as.matrix(outputs[7]), W_intermediate,
b_intermediate, T, I, D)
# Residual pass 2
dout_states = dout_states + dout_states_identity_2

# Layer norm 1 for each token
[dout_states, dgamma_ln1, dbeta_ln1] = layer_norm_backward(
dout_states, cache_mean_ln1, cache_var_ln1, cache_norm_ln1, as.matrix(outputs[6]), gamma_ln1, beta_ln1, epsilon_ln, T, D)
# Save dout_states for residual pass
dout_states_identity_1 = dout_states

# Dropout on output 1
if (dropout_p_output > 0.0) {
dout_states = dropout::backward(dout_states, matrix(0, 1, 1), dropout_p_output, dropout_mask_output_1)
}
# Linear layer on attention output (output layer)
[dcontext, dW_context, db_context] = linear_tensor_backward(dout_states, as.matrix(outputs[4]), W_context, b_context, T, D, D)

# Multi-head self attention
[dQ, dK, dV] = attention::backward(dcontext, dropout_mask_attention, as.matrix(outputs[5]), as.matrix(outputs[1]),
as.matrix(outputs[2]), as.matrix(outputs[3]), H, T, d, dropout_p_attention)

# Linear layers for Q, K, V
[dstates_Q, dW_Q, db_Q] = linear_tensor_backward(dQ, states, W_Q, b_Q, T, D, D)
[dstates_K, dW_K, db_K] = linear_tensor_backward(dK, states, W_K, b_K, T, D, D)
[dstates_V, dW_V, db_V] = linear_tensor_backward(dV, states, W_V, b_V, T, D, D)
# Add paths + residual pass 1
din_states = dstates_Q + dstates_K + dstates_V + dout_states_identity_1
}
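
Two patterns drive most of the bookkeeping above: a residual connection contributes its upstream gradient unchanged (the dout_states_identity_* terms), and the Q, K, and V projections all read the same input states, so their input gradients sum. A minimal NumPy sketch of both rules (f_backward and the weight matrices are hypothetical stand-ins, not the attention layers used here):

import numpy as np

def residual_backward(dout, f_backward):
    # Forward: y = f(x) + x, so dx = (backward through f) + dout
    return f_backward(dout) + dout

def shared_input_backward(dQ, dK, dV, W_Q, W_K, W_V):
    # Forward: Q = X @ W_Q, K = X @ W_K, V = X @ W_V all read the same X,
    # so dX is the sum of the three projection paths
    return dQ @ W_Q.T + dK @ W_K.T + dV @ W_V.T

# Tiny shape check with random data
W_Q, W_K, W_V = (np.random.randn(8, 8) for _ in range(3))
dQ, dK, dV = (np.random.randn(2, 8) for _ in range(3))
print(shared_input_backward(dQ, dK, dV, W_Q, W_K, W_V).shape)   # (2, 8)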
@@ -26,13 +26,15 @@

public class BertLayerTest extends AutomatedTestBase{
private static final String TEST_NAME_FORWARD = "bert_layer_forward";
private static final String TEST_NAME_BACKWARD = "bert_layer_backward";
private static final String TEST_DIR = "applications/nn/component/";
private static final String RESOURCE_DIR = "src/test/resources/component/transformers/bert_layer/";

@Override
public void setUp() {
TestUtils.clearAssertionInformation();
addTestConfiguration(TEST_NAME_FORWARD, new TestConfiguration(TEST_DIR, TEST_NAME_FORWARD));
addTestConfiguration(TEST_NAME_BACKWARD, new TestConfiguration(TEST_DIR, TEST_NAME_BACKWARD));
}

@Test
@@ -47,6 +49,18 @@ public void testBertLayerForwardNormalGelu() {
1e-5, true);
}

@Test
public void testBertLayerBackwardNormalGelu() {
runBertLayerTest("test3", 2, 3, 8, 2, 4, 5, "gelu", 0, TEST_NAME_BACKWARD,
1e-4, false);
}

@Test
public void testBertLayerBackwardSameDimsTanh() {
runBertLayerTest("test4", 4, 4, 4, 2, 2, 4, "tanh", 0, TEST_NAME_BACKWARD,
1e-4, false);
}

private void runBertLayerTest(String testSuffix, int batchSize, int seqLength, int embeddingDim, int numHeads,
int perHeadEmbeddingDim, int intermediateEmbeddingDim, String activation, int debug, String testname, double precision,
boolean isForward) {
@@ -88,6 +102,68 @@ private void runBertLayerTest(String testSuffix, int batchSize, int seqLength, i
output("states_error"),
output("attention_error"),
};
} else {
programArgs = new String[] {
"-stats", "-args",
String.valueOf(debug), String.valueOf(batchSize),
String.valueOf(seqLength), String.valueOf(embeddingDim),
String.valueOf(numHeads), String.valueOf(perHeadEmbeddingDim),
String.valueOf(intermediateEmbeddingDim), activation,
RESOURCE_DIR + "input_states_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_Q_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_Q_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_K_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_K_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_V_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_V_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_context_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_context_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_intermediate_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_intermediate_" + testSuffix + ".csv",
RESOURCE_DIR + "input_W_out_" + testSuffix + ".csv",
RESOURCE_DIR + "input_b_out_" + testSuffix + ".csv",
RESOURCE_DIR + "input_gamma_ln1_" + testSuffix + ".csv",
RESOURCE_DIR + "input_beta_ln1_" + testSuffix + ".csv",
RESOURCE_DIR + "input_gamma_ln2_" + testSuffix + ".csv",
RESOURCE_DIR + "input_beta_ln2_" + testSuffix + ".csv",
RESOURCE_DIR + "output_states_" + testSuffix + ".csv",
RESOURCE_DIR + "output_attention_" + testSuffix + ".csv",
RESOURCE_DIR + "input_dstates_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dstates_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dW_Q_" + testSuffix + ".csv",
RESOURCE_DIR + "output_db_Q_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dW_K_" + testSuffix + ".csv",
RESOURCE_DIR + "output_db_K_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dW_V_" + testSuffix + ".csv",
RESOURCE_DIR + "output_db_V_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dW_context_" + testSuffix + ".csv",
RESOURCE_DIR + "output_db_context_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dW_intermediate_" + testSuffix + ".csv",
RESOURCE_DIR + "output_db_intermediate_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dW_out_" + testSuffix + ".csv",
RESOURCE_DIR + "output_db_out_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dgamma_ln1_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dbeta_ln1_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dgamma_ln2_" + testSuffix + ".csv",
RESOURCE_DIR + "output_dbeta_ln2_" + testSuffix + ".csv",
output("din_error"),
output("dW_Q_error"),
output("db_Q_error"),
output("dW_K_error"),
output("db_K_error"),
output("dW_V_error"),
output("db_V_error"),
output("dW_context_error"),
output("db_context_error"),
output("dW_intermediate_error"),
output("db_intermediate_error"),
output("dW_out_error"),
output("db_out_error"),
output("dgamma_ln1_error"),
output("dbeta_ln1_error"),
output("dgamma_ln2_error"),
output("dbeta_ln2_error"),
};
}

// Run the test
@@ -100,12 +176,40 @@ private void runBertLayerTest(String testSuffix, int batchSize, int seqLength, i
double attentionMaxError = (Double) readDMLScalarFromOutputDir("attention_error").values().toArray()[0];
assert attentionMaxError < precision;
} else {
double dqueryMaxError = (Double) readDMLScalarFromOutputDir("dquery_error").values().toArray()[0];
assert dqueryMaxError < precision;
double dkeyMaxError = (Double) readDMLScalarFromOutputDir("dkey_error").values().toArray()[0];
assert dkeyMaxError < precision;
double dvalueMaxError = (Double) readDMLScalarFromOutputDir("dvalue_error").values().toArray()[0];
assert dvalueMaxError < precision;
double dinMaxError = (Double) readDMLScalarFromOutputDir("din_error").values().toArray()[0];
assert dinMaxError < precision;
double dWQMaxError = (Double) readDMLScalarFromOutputDir("dW_Q_error").values().toArray()[0];
assert dWQMaxError < precision;
double dbQMaxError = (Double) readDMLScalarFromOutputDir("db_Q_error").values().toArray()[0];
assert dbQMaxError < precision;
double dWKMaxError = (Double) readDMLScalarFromOutputDir("dW_K_error").values().toArray()[0];
assert dWKMaxError < precision;
double dbKMaxError = (Double) readDMLScalarFromOutputDir("db_K_error").values().toArray()[0];
assert dbKMaxError < precision;
double dWVMaxError = (Double) readDMLScalarFromOutputDir("dW_V_error").values().toArray()[0];
assert dWVMaxError < precision;
double dbVMaxError = (Double) readDMLScalarFromOutputDir("db_V_error").values().toArray()[0];
assert dbVMaxError < precision;
double dWContextMaxError = (Double) readDMLScalarFromOutputDir("dW_context_error").values().toArray()[0];
assert dWContextMaxError < precision;
double dbContextMaxError = (Double) readDMLScalarFromOutputDir("db_context_error").values().toArray()[0];
assert dbContextMaxError < precision;
double dWIntermediateMaxError = (Double) readDMLScalarFromOutputDir("dW_intermediate_error").values().toArray()[0];
assert dWIntermediateMaxError < precision;
double dbIntermediateMaxError = (Double) readDMLScalarFromOutputDir("db_intermediate_error").values().toArray()[0];
assert dbIntermediateMaxError < precision;
double dWOutMaxError = (Double) readDMLScalarFromOutputDir("dW_out_error").values().toArray()[0];
assert dWOutMaxError < precision;
double dbOutMaxError = (Double) readDMLScalarFromOutputDir("db_out_error").values().toArray()[0];
assert dbOutMaxError < precision;
double dgammaLn1MaxError = (Double) readDMLScalarFromOutputDir("dgamma_ln1_error").values().toArray()[0];
assert dgammaLn1MaxError < precision;
double dbetaLn1MaxError = (Double) readDMLScalarFromOutputDir("dbeta_ln1_error").values().toArray()[0];
assert dbetaLn1MaxError < precision;
double dgammaLn2MaxError = (Double) readDMLScalarFromOutputDir("dgamma_ln2_error").values().toArray()[0];
assert dgammaLn2MaxError < precision;
double dbetaLn2MaxError = (Double) readDMLScalarFromOutputDir("dbeta_ln2_error").values().toArray()[0];
assert dbetaLn2MaxError < precision;
}
} catch (Throwable ex) {
ex.printStackTrace(System.out); // Log or debug all exceptions or errors
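
The assertions above only read back max-error scalars written by the DML test script; conceptually, each gradient is compared element-wise against a precomputed reference CSV. A rough Python sketch of that kind of check, with toy data standing in for the real CSVs and the 1e-4 tolerance from the backward tests:

import numpy as np

def max_abs_error(computed, expected):
    # Largest element-wise deviation; the *_error outputs hold this value per gradient
    return float(np.max(np.abs(computed - expected)))

expected = np.array([[0.1, -0.2], [0.3, 0.4]])       # stands in for a reference CSV
computed = expected + 1e-7 * np.random.randn(2, 2)   # stands in for the DML gradient
assert max_abs_error(computed, expected) < 1e-4      # mirrors e.g. dW_Q_error < precision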
@@ -0,0 +1,8 @@
0.138744,-0.123685,0.009826,-0.317605,-0.071714,-0.068100,-0.318219,-0.040798
0.021178,-0.289780,-0.030501,-0.167612,0.193891,-0.069418,-0.023860,-0.157829
-0.172509,-0.075206,0.071553,0.240736,0.191146,-0.317261,0.310922,0.282720
0.167298,0.075574,0.224803,-0.002292,-0.340978,-0.305271,-0.144212,-0.285705
-0.339146,-0.230328,0.334902,-0.175732,0.220540,-0.055324,0.319260,0.037938
-0.209553,-0.018144,0.224526,-0.270932,-0.276659,0.004572,0.128041,-0.074023
-0.088505,0.253091,0.335668,-0.330874,-0.074745,-0.160610,-0.319068,0.252477
-0.172221,-0.036345,-0.025570,-0.298402,-0.143356,0.133183,0.223692,0.098693
@@ -0,0 +1,4 @@
0.093586,0.192278,0.357936,0.249658
-0.084230,-0.296152,0.186956,0.104651
-0.082281,0.183296,-0.494868,-0.390042
-0.228878,0.252854,-0.324348,-0.287910