Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement DeepSeek V3/R1 #2745

Open
wants to merge 8 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .vscode/settings.json
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@
"candle-pyo3"
],
"python.testing.unittestEnabled": false,
"python.testing.pytestEnabled": true
"python.testing.pytestEnabled": true,
"rust-analyzer.cargo.features": ["cuda", "nccl"]
}
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ fancy-regex = "0.13.0"
gemm = { version = "0.17.0", features = ["wasm-simd128-enable"] }
hf-hub = "0.4.1"
half = { version = "2.3.1", features = ["num-traits", "use-intrinsics", "rand_distr"] }
float8 = { version = "0.1.3", features = ["num-traits", "rand_distr"] }
hound = "3.5.1"
image = { version = "0.25.2", default-features = false, features = ["jpeg", "png"] }
imageproc = { version = "0.24.0", default-features = false }
Expand Down
3 changes: 2 additions & 1 deletion candle-core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ metal = { workspace = true, optional = true}
cudarc = { workspace = true, optional = true }
gemm = { workspace = true }
half = { workspace = true }
float8 = { workspace = true }
intel-mkl-src = { workspace = true, optional = true }
libc = { workspace = true, optional = true }
memmap2 = { workspace = true }
Expand All @@ -42,7 +43,7 @@ criterion = { workspace = true }

[features]
default = []
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda"]
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"]
cudnn = ["cuda", "cudarc/cudnn"]
mkl = ["dep:libc", "dep:intel-mkl-src"]
accelerate = ["dep:libc", "dep:accelerate-src"]
Expand Down
6 changes: 6 additions & 0 deletions candle-core/src/convert.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
//! Implement conversion traits for tensors
use crate::{DType, Device, Error, Tensor, WithDType};
use float8::F8E4M3;
use half::{bf16, f16, slice::HalfFloatSliceExt};
use std::convert::TryFrom;

Expand Down Expand Up @@ -139,6 +140,11 @@ impl Tensor {
let vs = vs.to_vec1::<u8>()?;
f.write_all(&vs)?;
}
DType::F8E4M3 => {
for v in vs.to_vec1::<F8E4M3>()? {
f.write_u8(v.to_bits())?
}
}
}
Ok(())
}
Expand Down
117 changes: 117 additions & 0 deletions candle-core/src/cpu_backend/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
use crate::backend::{BackendDevice, BackendStorage};
use crate::op::{BinaryOpT, CmpOp, ReduceOp, UnaryOpT};
use crate::{DType, Error, IntDType, Layout, Result, Shape, WithDType};
use float8::F8E4M3;
use half::{bf16, f16};
use rayon::prelude::*;

Expand All @@ -25,6 +26,7 @@ pub enum CpuStorage {
F16(Vec<f16>),
F32(Vec<f32>),
F64(Vec<f64>),
F8E4M3(Vec<F8E4M3>),
}

#[derive(Debug, Clone)]
Expand All @@ -36,6 +38,7 @@ pub enum CpuStorageRef<'a> {
F16(&'a [f16]),
F32(&'a [f32]),
F64(&'a [f64]),
F8E4M3(&'a [F8E4M3]),
}

#[derive(Debug, Clone)]
Expand Down Expand Up @@ -1623,6 +1626,17 @@ impl CpuStorage {
.concat();
Self::F64(storages)
}
Self::F8E4M3(_) => {
let storages = storages
.iter()
.map(|s| match s {
Self::F8E4M3(s) => Ok(s.as_slice()),
_ => crate::bail!("dtype mismatch"),
})
.collect::<Result<Vec<_>>>()?
.concat();
Self::F8E4M3(storages)
}
};
Ok(s)
}
Expand All @@ -1640,6 +1654,7 @@ impl BackendStorage for CpuStorage {
Self::F16(_) => DType::F16,
Self::F32(_) => DType::F32,
Self::F64(_) => DType::F64,
Self::F8E4M3(_) => DType::F8E4M3,
}
}

Expand Down Expand Up @@ -1674,6 +1689,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, bf16::from_f64);
Ok(Self::BF16(data))
}
(Self::F8E4M3(storage), DType::BF16) => {
let data = unary_map(storage, layout, |v| bf16::from_f32(v.to_f32()));
Ok(Self::BF16(data))
}
(Self::U8(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v as f32));
Ok(Self::F16(data))
Expand Down Expand Up @@ -1702,6 +1721,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, f16::from_f64);
Ok(Self::F16(data))
}
(Self::F8E4M3(storage), DType::F16) => {
let data = unary_map(storage, layout, |v| f16::from_f32(v.to_f32()));
Ok(Self::F16(data))
}
(Self::U8(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
Expand Down Expand Up @@ -1730,6 +1753,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as f32);
Ok(Self::F32(data))
}
(Self::F8E4M3(storage), DType::F32) => {
let data = unary_map(storage, layout, |v| v.to_f32());
Ok(Self::F32(data))
}
(Self::U8(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::U8(data))
Expand Down Expand Up @@ -1758,6 +1785,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u8);
Ok(Self::U8(data))
}
(Self::F8E4M3(storage), DType::U8) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u8);
Ok(Self::U8(data))
}
(Self::U8(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
Expand Down Expand Up @@ -1786,6 +1817,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as u32);
Ok(Self::U32(data))
}
(Self::F8E4M3(storage), DType::U32) => {
let data = unary_map(storage, layout, |v| v.to_f32() as u32);
Ok(Self::U32(data))
}
(Self::U8(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
Expand Down Expand Up @@ -1814,6 +1849,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v as i64);
Ok(Self::I64(data))
}
(Self::F8E4M3(storage), DType::I64) => {
let data = unary_map(storage, layout, |v| v.to_f32() as i64);
Ok(Self::I64(data))
}
(Self::U8(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v as f64);
Ok(Self::F64(data))
Expand Down Expand Up @@ -1842,6 +1881,42 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F64(data))
}
(Self::F8E4M3(storage), DType::F64) => {
let data = unary_map(storage, layout, |v| v.to_f64());
Ok(Self::F64(data))
}
(Self::U8(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::U32(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::I64(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v as f32));
Ok(Self::F8E4M3(data))
}
(Self::BF16(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from(v.to_f32()));
Ok(Self::F8E4M3(data))
}
(Self::F16(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| F8E4M3::from_f32(v.to_f32()));
Ok(Self::F8E4M3(data))
}
(Self::F32(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, F8E4M3::from_f32);
Ok(Self::F8E4M3(data))
}
(Self::F64(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, F8E4M3::from_f64);
Ok(Self::F8E4M3(data))
}
(Self::F8E4M3(storage), DType::F8E4M3) => {
let data = unary_map(storage, layout, |v| v);
Ok(Self::F8E4M3(data))
}
}
}

Expand Down Expand Up @@ -1955,6 +2030,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| v.powf(e));
Ok(Self::F64(data))
}
Self::F8E4M3(storage) => {
let data = unary_map(storage, layout, |v| v.powf(F8E4M3::from_f64(e)));
Ok(Self::F8E4M3(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
Expand All @@ -1980,6 +2059,10 @@ impl BackendStorage for CpuStorage {
let data = unary_map(storage, layout, |v| elu(v, alpha));
Ok(Self::F64(data))
}
Self::F8E4M3(storage) => {
let data = unary_map(storage, layout, |v| elu(v, F8E4M3::from_f64(alpha)));
Ok(Self::F8E4M3(data))
}
Self::U8(_) => Err(Error::UnsupportedDTypeForOp(DType::U8, "elu").bt()),
Self::U32(_) => Err(Error::UnsupportedDTypeForOp(DType::U32, "elu").bt()),
Self::I64(_) => Err(Error::UnsupportedDTypeForOp(DType::I64, "elu").bt()),
Expand Down Expand Up @@ -2024,6 +2107,15 @@ impl BackendStorage for CpuStorage {
Ok(Self::F64(data))
}
}
Self::F8E4M3(storage) => {
if B::F8E4M3_VEC {
let data = unary_map_vec(storage, layout, B::f8e4m3, B::f8e4m3_vec);
Ok(Self::F8E4M3(data))
} else {
let data = unary_map(storage, layout, B::f8e4m3);
Ok(Self::F8E4M3(data))
}
}
Self::U8(storage) => {
let data = unary_map(storage, layout, B::u8);
Ok(Self::U8(data))
Expand Down Expand Up @@ -2505,6 +2597,15 @@ impl BackendDevice for CpuDevice {
}
Ok(CpuStorage::F16(data))
}
DType::F8E4M3 => {
let mut data = Vec::with_capacity(elem_count);
let uniform =
rand::distributions::Uniform::new(F8E4M3::from_f64(min), F8E4M3::from_f64(max));
for _i in 0..elem_count {
data.push(rng.sample::<F8E4M3, _>(uniform))
}
Ok(CpuStorage::F8E4M3(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let uniform = rand::distributions::Uniform::new(min as f32, max as f32);
Expand Down Expand Up @@ -2551,6 +2652,15 @@ impl BackendDevice for CpuDevice {
}
Ok(CpuStorage::F16(data))
}
DType::F8E4M3 => {
let mut data = Vec::with_capacity(elem_count);
let normal = rand_distr::Normal::new(F8E4M3::from_f64(mean), F8E4M3::from_f64(std))
.map_err(Error::wrap)?;
for _i in 0..elem_count {
data.push(normal.sample(&mut rng))
}
Ok(CpuStorage::F8E4M3(data))
}
DType::F32 => {
let mut data = Vec::with_capacity(elem_count);
let normal =
Expand Down Expand Up @@ -2614,6 +2724,11 @@ impl BackendDevice for CpuDevice {
v.set_len(elem_count);
CpuStorage::F64(v)
}
DType::F8E4M3 => {
let mut v = Vec::with_capacity(elem_count);
v.set_len(elem_count);
CpuStorage::F8E4M3(v)
}
};
Ok(storage)
}
Expand All @@ -2626,6 +2741,7 @@ impl BackendDevice for CpuDevice {
DType::I64 => CpuStorage::I64(vec![1i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ONE; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ONE; elem_count]),
DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ONE; elem_count]),
DType::F32 => CpuStorage::F32(vec![1f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![1f64; elem_count]),
};
Expand All @@ -2640,6 +2756,7 @@ impl BackendDevice for CpuDevice {
DType::I64 => CpuStorage::I64(vec![0i64; elem_count]),
DType::BF16 => CpuStorage::BF16(vec![bf16::ZERO; elem_count]),
DType::F16 => CpuStorage::F16(vec![f16::ZERO; elem_count]),
DType::F8E4M3 => CpuStorage::F8E4M3(vec![F8E4M3::ZERO; elem_count]),
DType::F32 => CpuStorage::F32(vec![0f32; elem_count]),
DType::F64 => CpuStorage::F64(vec![0f64; elem_count]),
};
Expand Down
4 changes: 4 additions & 0 deletions candle-core/src/cpu_backend/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub trait Map1 {
C::F16(vs) => Ok(C::F16(self.f(vs, layout)?)),
C::F32(vs) => Ok(C::F32(self.f(vs, layout)?)),
C::F64(vs) => Ok(C::F64(self.f(vs, layout)?)),
C::F8E4M3(vs) => Ok(C::F8E4M3(self.f(vs, layout)?)),
}
}
}
Expand All @@ -31,6 +32,7 @@ pub trait Map1Any {
C::F16(vs) => Ok(self.f(vs, layout, C::F16)?),
C::F32(vs) => Ok(self.f(vs, layout, C::F32)?),
C::F64(vs) => Ok(self.f(vs, layout, C::F64)?),
C::F8E4M3(vs) => Ok(self.f(vs, layout, C::F8E4M3)?),
}
}
}
Expand All @@ -48,6 +50,7 @@ pub trait Map2 {
(C::F16(v1), C::F16(v2)) => Ok(C::F16(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::F32(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::F64(self.f(v1, l1, v2, l2)?)),
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::F8E4M3(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
Expand All @@ -71,6 +74,7 @@ pub trait Map2U8 {
(C::F16(v1), C::F16(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F32(v1), C::F32(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F64(v1), C::F64(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
(C::F8E4M3(v1), C::F8E4M3(v2)) => Ok(C::U8(self.f(v1, l1, v2, l2)?)),
_ => Err(Error::DTypeMismatchBinaryOp {
lhs: v1.dtype(),
rhs: v2.dtype(),
Expand Down
Loading
Loading