Commit: test kernel for rmsnorm
DhruvDh committed Mar 7, 2024
1 parent aee78b8 commit 1d49d42
Showing 5 changed files with 89 additions and 47 deletions.
6 changes: 6 additions & 0 deletions .gitignore
@@ -1,3 +1,9 @@
.DS_Store
Manifest.toml
bin/
+
+# TODO: remove these
+*.patch
+*main.jl
+profile*
+#~*
3 changes: 3 additions & 0 deletions Project.toml
@@ -4,9 +4,12 @@ authors = ["Lukas Mayrhofer <[email protected]>"]
version = "0.1.0"

[deps]
+Atomix = "a9b6321e-bd34-4604-b9c9-b65b8de01458"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
+KernelAbstractions = "63c18a36-062a-441e-b654-da1e3ab1ce7c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
+PProf = "e4faabce-9ead-11e9-39d9-4379958e3056"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
4 changes: 4 additions & 0 deletions src/Llama2.jl
@@ -8,11 +8,15 @@ using SIMD
using LoopVectorization
using Random
using Distributions
+using KernelAbstractions
+using Atomix

export ModelConfig, CharTokenizer, LanguageModel
export load_gguf_model, load_karpathy_model, encode, sample
export train

+const global backend = CPU()
+
# quantization
include("quantization/utils.jl")
include("quantization/common.jl")
91 changes: 60 additions & 31 deletions src/inference.jl
@@ -61,8 +61,8 @@ struct KVCache
end

KVCache(head_size::Int, n_heads::Int, seq_len::Int) = KVCache(
-    zeros(Float32, head_size, n_heads, seq_len),
-    zeros(Float32, seq_len, head_size, n_heads),
+    KernelAbstractions.zeros(backend, Float32, head_size, n_heads, seq_len),
+    KernelAbstractions.zeros(backend, Float32, seq_len, head_size, n_heads),
)

@kwdef struct RunState
@@ -82,26 +82,55 @@ KVCache(head_size::Int, n_heads::Int, seq_len::Int) = KVCache(
end

RunState(c::ModelConfig) = RunState(;
-    x = zeros(Float32, c.dim),
-    xb = zeros(Float32, c.dim),
-    xb2 = zeros(Float32, c.dim),
-    hb = zeros(Float32, c.hidden_dim),
-    hb2 = zeros(Float32, c.hidden_dim),
-    q = zeros(Float32, c.dim),
-    k = zeros(Float32, c.dim),
-    v = zeros(Float32, c.dim),
-    att = zeros(Float32, c.seq_len * c.n_heads),
-    logits = zeros(Float32, c.vocab_size),
-    kvcache_layers = [KVCache(c.dim ÷ c.n_heads, c.n_heads, c.seq_len) for _ in 1:c.n_layers],
+    x=KernelAbstractions.zeros(backend, Float32, c.dim),
+    xb=KernelAbstractions.zeros(backend, Float32, c.dim),
+    xb2=KernelAbstractions.zeros(backend, Float32, c.dim),
+    hb=KernelAbstractions.zeros(backend, Float32, c.hidden_dim),
+    hb2=KernelAbstractions.zeros(backend, Float32, c.hidden_dim),
+    q=KernelAbstractions.zeros(backend, Float32, c.dim),
+    k=KernelAbstractions.zeros(backend, Float32, c.dim),
+    v=KernelAbstractions.zeros(backend, Float32, c.dim),
+    att=KernelAbstractions.zeros(backend, Float32, c.seq_len * c.n_heads),
+    logits=KernelAbstractions.zeros(backend, Float32, c.vocab_size),
+    kvcache_layers=[KVCache(c.dim ÷ c.n_heads, c.n_heads, c.seq_len) for _ in 1:c.n_layers],
)

+@kernel function self_dot_kernel!(result, x)
+    li = @index(Local, Linear)
+    I = @index(Global, Linear)
+
+    TILE_DIM = @uniform @groupsize()[1]
+
+    intermediate_result_tile = @localmem eltype(result) (TILE_DIM)
+    intermediate_result_tile[li] = x[I] * x[I]
+
+    @synchronize()
+    KernelAbstractions.Extras.@unroll for i in [1, 2, 4, 8, 16, 32]
+        if (li + i) <= TILE_DIM
+            intermediate_result_tile[li] += intermediate_result_tile[li+i]
+        end
+        @synchronize()
+    end
+
+    if li == 1
+        Atomix.@atomic result[1] += intermediate_result_tile[1]
+    end
+end
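Note on the reduction above: with ascending strides [1, 2, 4, 8, 16, 32], a work-item reads a slot that its neighbour may be updating in the same step, so the sum only comes out right when work-items between @synchronize points run in ascending order, which the sequential CPU() backend happens to do; on a GPU backend this is a data race, and the final stride of 32 can never pass the bounds check. A conventional tree reduction halves the stride instead and is race-free. A sketch of that variant (self_dot_kernel_safe! is hypothetical, not part of this commit, and assumes the workgroup size of 32 used at the launch site in rmsnorm! below):

    # Race-free tree reduction: at stride s, work-items 1..s accumulate
    # slots s+1..2s, which no other work-item writes during that step.
    @kernel function self_dot_kernel_safe!(result, x)
        li = @index(Local, Linear)
        I = @index(Global, Linear)

        tile = @localmem eltype(result) (32)
        tile[li] = x[I] * x[I]

        @synchronize()
        KernelAbstractions.Extras.@unroll for s in [16, 8, 4, 2, 1]
            if li <= s
                tile[li] += tile[li+s]
            end
            @synchronize()
        end

        # one atomic add per workgroup folds the partial sum into the result
        if li == 1
            Atomix.@atomic result[1] += tile[1]
        end
    end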

+@kernel function rmsnorm_inner_kernel!(o, @Const(x), @Const(weight), @Const(ss))
+    I = @index(Global, Linear)
+    @inbounds o[I] = weight[I] * (ss * x[I])
+end
+
function rmsnorm!(o, x, weight)
-    ss = dot(x, x)
-    ss /= length(x)
-    ss += 1f-6
-    ss = 1f0 / sqrt(ss)
+    dotresult = KernelAbstractions.zeros(backend, Float32, 1)
+    self_dot_kernel!(backend, 32, size(x))(dotresult, x, ndrange=size(x))
+    ss = dotresult[1] / length(x)
+    ss += 1.0f-6
+    ss = 1.0f0 / sqrt(ss)
    # normalize and scale
-    o .= weight .* (ss .* x)
+    rmsnorm_inner_kernel!(backend)(o, x, weight, ss, ndrange=size(weight))
+    # o .= weight .* (ss .* x)
    return nothing
end
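The kernel path should reproduce what the replaced broadcast computed, so a quick consistency check can be run against the closed-form expression (a sketch, not part of this commit; it assumes the CPU() backend set in src/Llama2.jl, where reading dotresult[1] right after the launch is safe, and an input length that is a multiple of the 32-wide workgroup, since self_dot_kernel! has no bounds guard; a GPU backend would also need a KernelAbstractions.synchronize(backend) before the read):

    x = rand(Float32, 288)   # 288 = 9 * 32, a multiple of the workgroup size
    w = rand(Float32, 288)
    o = similar(x)
    rmsnorm!(o, x, w)

    # reference: the formula the removed broadcast computed directly
    ss = 1.0f0 / sqrt(sum(abs2, x) / length(x) + 1.0f-6)
    @assert o ≈ w .* (ss .* x)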

@@ -112,7 +141,7 @@ function rope!(x::AbstractMatrix{Float32}, pos::Int)
freq_base = 10000.0f0
freq_scale = 1.0f0

-    theta_scale = freq_base ^ (-inv(Float32(head_size_div2)))
+    theta_scale = freq_base^(-inv(Float32(head_size_div2)))

@inbounds for head in 1:n_heads
theta = freq_scale * (pos - 1)
@@ -136,7 +165,7 @@ end
function attention_weights!(att, key_cache, q)
@inbounds @fastmath for h in axes(att, 2)
for t in axes(att, 1)
-            s = 0f0
+            s = 0.0f0

for i in axes(q, 1)
s += q[i, h] * key_cache[i, h, t]
@@ -152,7 +181,7 @@ end
function combine_values!(xb, value_cache, att)
@inbounds @fastmath for h in axes(xb, 2)
for i in axes(xb, 1)
-            s = 0f0
+            s = 0.0f0

for t in axes(att, 1)
s += att[t, h] * value_cache[t, i, h]
@@ -238,7 +267,7 @@ end

    # F.silu: silu(x) = x*σ(x), where σ(x) is the logistic sigmoid
for i in 1:hidden_dim
-        s.hb[i] = s.hb[i] * (1f0 / (1f0 + exp(-s.hb[i])))
+        s.hb[i] = s.hb[i] * (1.0f0 / (1.0f0 + exp(-s.hb[i])))
end

s.hb .*= s.hb2
@@ -260,13 +289,13 @@ end
end

function sample(
-    model::LanguageModel,
-    prompt::String = "";
-    temperature::Float32 = 0.9f0,
-    stop_on_special_token = true,
-    max_seq_len = typemax(Int),
-    bos_token = true,
-)
+    model::LanguageModel,
+    prompt::String="";
+    temperature::Float32=0.9f0,
+    stop_on_special_token=true,
+    max_seq_len=typemax(Int),
+    bos_token=true,
+)

if !bos_token && isempty(prompt)
error("Prompt cannot be empty if bos_token = false")
@@ -299,11 +328,11 @@ function sample(
transformer!(token, pos, config, state, weights)
generated_seq_len += 1

-        if pos+1 <= length(prompt_tokens)
+        if pos + 1 <= length(prompt_tokens)
next = prompt_tokens[pos+1]
else
# sample the next token
-            if temperature == 0f0
+            if temperature == 0.0f0
# greedy argmax sampling
next = argmax(state.logits)
else
32 changes: 16 additions & 16 deletions src/load_karpathy.jl
@@ -9,22 +9,22 @@ read_karpathy_config(f::IOStream) = ModelConfig(
)

TransformerLayerWeights(p::ModelConfig) = TransformerLayerWeights(;
-    rms_att_weight = zeros(Float32, p.dim),
-    rms_ffn_weight = zeros(Float32, p.dim),
-    wq = zeros(Float32, p.dim, p.dim),
-    wk = zeros(Float32, p.dim, p.dim),
-    wv = zeros(Float32, p.dim, p.dim),
-    wo = zeros(Float32, p.dim, p.dim),
-    w1 = zeros(Float32, p.dim, p.hidden_dim),
-    w2 = zeros(Float32, p.hidden_dim, p.dim),
-    w3 = zeros(Float32, p.dim, p.hidden_dim),
+    rms_att_weight=KernelAbstractions.zeros(backend, Float32, p.dim),
+    rms_ffn_weight=KernelAbstractions.zeros(backend, Float32, p.dim),
+    wq=KernelAbstractions.zeros(backend, Float32, p.dim, p.dim),
+    wk=KernelAbstractions.zeros(backend, Float32, p.dim, p.dim),
+    wv=KernelAbstractions.zeros(backend, Float32, p.dim, p.dim),
+    wo=KernelAbstractions.zeros(backend, Float32, p.dim, p.dim),
+    w1=KernelAbstractions.zeros(backend, Float32, p.dim, p.hidden_dim),
+    w2=KernelAbstractions.zeros(backend, Float32, p.hidden_dim, p.dim),
+    w3=KernelAbstractions.zeros(backend, Float32, p.dim, p.hidden_dim),
)

TransformerWeights(p::ModelConfig) = TransformerWeights(;
-    token_embedding_table = zeros(Float32, p.dim, p.vocab_size),
-    layers = [TransformerLayerWeights(p) for _ in 1:p.n_layers],
-    rms_final_weight = zeros(Float32, p.dim),
-    output_weight = zeros(Float32, p.dim, p.vocab_size),
+    token_embedding_table=KernelAbstractions.zeros(backend, Float32, p.dim, p.vocab_size),
+    layers=[TransformerLayerWeights(p) for _ in 1:p.n_layers],
+    rms_final_weight=KernelAbstractions.zeros(backend, Float32, p.dim),
+    output_weight=KernelAbstractions.zeros(backend, Float32, p.dim, p.vocab_size),
)

function read_karpathy_weights(f::IOStream, config::ModelConfig)
@@ -76,9 +76,9 @@ function load_karpathy_tokenizer(filename::AbstractString, vocab_size::Int)
end

function load_karpathy_model(
-    checkpoint_filename::AbstractString,
-    tokenizer_filename::AbstractString,
-)
+    checkpoint_filename::AbstractString,
+    tokenizer_filename::AbstractString,
+)

config = nothing
weights = nothing
