diff --git a/Cargo.lock b/Cargo.lock index c4c4a20d2f1e..7983da26c257 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2409,6 +2409,40 @@ dependencies = [ "spin 0.5.2", ] +[[package]] +name = "lber" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2df7f9fd9f64cf8f59e1a4a0753fe7d575a5b38d3d7ac5758dcee9357d83ef0a" +dependencies = [ + "bytes", + "nom", +] + +[[package]] +name = "ldap3" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "166199a8207874a275144c8a94ff6eed5fcbf5c52303e4d9b4d53a0c7ac76554" +dependencies = [ + "async-trait", + "bytes", + "futures", + "futures-util", + "lazy_static", + "lber", + "log", + "native-tls", + "nom", + "percent-encoding", + "thiserror", + "tokio", + "tokio-native-tls", + "tokio-stream", + "tokio-util", + "url", +] + [[package]] name = "leaky-bucket" version = "1.1.2" @@ -5448,6 +5482,7 @@ dependencies = [ "hash-ids", "juniper", "lazy_static", + "ldap3", "regex", "serde", "strum 0.24.1", @@ -5487,6 +5522,7 @@ dependencies = [ "juniper_axum", "juniper_graphql_ws", "lazy_static", + "ldap3", "lettre", "logkit", "mime_guess", diff --git a/Cargo.toml b/Cargo.toml index 9488a6eca551..1d8a9d22ef9c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -67,6 +67,7 @@ mime_guess = "2.0.4" assert_matches = "1.5" insta = "1.34.0" logkit = "0.3" +ldap3 = "0.11.0" async-openai = "0.20" tracing-test = "0.2" clap = "4.3.0" diff --git a/ee/tabby-db/src/lib.rs b/ee/tabby-db/src/lib.rs index f40196ddf903..a73465aff58a 100644 --- a/ee/tabby-db/src/lib.rs +++ b/ee/tabby-db/src/lib.rs @@ -8,6 +8,7 @@ pub use email_setting::EmailSettingDAO; pub use integrations::IntegrationDAO; pub use invitations::InvitationDAO; pub use job_runs::JobRunDAO; +pub use ldap_credential::LdapCredentialDAO; pub use notifications::NotificationDAO; pub use oauth_credential::OAuthCredentialDAO; pub use provided_repositories::ProvidedRepositoryDAO; diff --git a/ee/tabby-schema/Cargo.toml b/ee/tabby-schema/Cargo.toml index a6849542d615..30eb999a675a 100644 --- a/ee/tabby-schema/Cargo.toml +++ b/ee/tabby-schema/Cargo.toml @@ -30,6 +30,7 @@ validator = { version = "0.18.1", features = ["derive"] } regex.workspace = true hash-ids.workspace = true url.workspace = true +ldap3.workspace = true [dev-dependencies] tabby-db = { path = "../../ee/tabby-db", features = ["testutils"]} diff --git a/ee/tabby-schema/graphql/schema.graphql b/ee/tabby-schema/graphql/schema.graphql index fef920d952c2..b25f4d292225 100644 --- a/ee/tabby-schema/graphql/schema.graphql +++ b/ee/tabby-schema/graphql/schema.graphql @@ -10,6 +10,13 @@ enum AuthMethod { LOGIN } +enum AuthProviderKind { + OAUTH_GITHUB + OAUTH_GOOGLE + OAUTH_GITLAB + LDAP +} + "Represents the kind of context source." enum ContextSourceKind { GIT @@ -61,6 +68,12 @@ enum Language { OTHER } +enum LdapEncryptionKind { + NONE + START_TLS + LDAPS +} + enum LicenseStatus { OK EXPIRED @@ -231,6 +244,19 @@ input UpdateIntegrationInput { kind: IntegrationKind! } +input UpdateLdapCredentialInput { + host: String! + port: Int! + bindDn: String! + bindPassword: String + baseDn: String! + userFilter: String! + encryption: LdapEncryptionKind! + skipTlsVerify: Boolean! + emailAttribute: String! + nameAttribute: String +} + input UpdateMessageInput { id: ID! threadId: ID! @@ -288,6 +314,10 @@ interface User { """ scalar DateTime +type AuthProvider { + kind: AuthProviderKind! +} + type CompletionStats { start: DateTime! end: DateTime! @@ -464,6 +494,20 @@ type JobStats { pending: Int! } +type LdapCredential { + host: String! + port: Int! + bindDn: String! + baseDn: String! + userFilter: String! + encryption: LdapEncryptionKind! + skipTlsVerify: Boolean! + emailAttribute: String! + nameAttribute: String + createdAt: DateTime! + updatedAt: DateTime! +} + type LicenseInfo { type: LicenseType! status: LicenseStatus! @@ -574,6 +618,7 @@ type Mutation { updateUserName(id: ID!, name: String!): Boolean! register(email: String!, password1: String!, password2: String!, invitationCode: String, name: String!): RegisterResponse! tokenAuth(email: String!, password: String!): TokenAuthResponse! + tokenAuthLdap(userId: String!, password: String!): TokenAuthResponse! verifyToken(token: String!): Boolean! refreshToken(refreshToken: String!): RefreshTokenResponse! createInvitation(email: String!): ID! @@ -585,6 +630,9 @@ type Mutation { deleteInvitation(id: ID!): ID! updateOauthCredential(input: UpdateOAuthCredentialInput!): Boolean! deleteOauthCredential(provider: OAuthProvider!): Boolean! + testLdapConnection(input: UpdateLdapCredentialInput!): Boolean! + updateLdapCredential(input: UpdateLdapCredentialInput!): Boolean! + deleteLdapCredential: Boolean! updateEmailSetting(input: EmailSettingInput!): Boolean! updateSecuritySetting(input: SecuritySettingInput!): Boolean! updateNetworkSetting(input: NetworkSettingInput!): Boolean! @@ -716,8 +764,10 @@ type Query { * `func_name lang:go` """ repositoryGrep(kind: RepositoryKind!, id: ID!, rev: String, query: String!): RepositoryGrepOutput! + authProviders: [AuthProvider!]! oauthCredential(provider: OAuthProvider!): OAuthCredential oauthCallbackUrl(provider: OAuthProvider!): String! + ldapCredential: LdapCredential serverInfo: ServerInfo! license: LicenseInfo! jobs: [String!]! diff --git a/ee/tabby-schema/src/dao.rs b/ee/tabby-schema/src/dao.rs index 60b7c64c8073..117bd62d5274 100644 --- a/ee/tabby-schema/src/dao.rs +++ b/ee/tabby-schema/src/dao.rs @@ -2,19 +2,20 @@ use anyhow::bail; use hash_ids::HashIds; use lazy_static::lazy_static; use tabby_db::{ - EmailSettingDAO, IntegrationDAO, InvitationDAO, JobRunDAO, NotificationDAO, OAuthCredentialDAO, - ServerSettingDAO, ThreadDAO, ThreadMessageAttachmentClientCode, ThreadMessageAttachmentCode, - ThreadMessageAttachmentDoc, ThreadMessageAttachmentIssueDoc, ThreadMessageAttachmentPullDoc, - ThreadMessageAttachmentWebDoc, UserEventDAO, + EmailSettingDAO, IntegrationDAO, InvitationDAO, JobRunDAO, LdapCredentialDAO, NotificationDAO, + OAuthCredentialDAO, ServerSettingDAO, ThreadDAO, ThreadMessageAttachmentClientCode, + ThreadMessageAttachmentCode, ThreadMessageAttachmentDoc, ThreadMessageAttachmentIssueDoc, + ThreadMessageAttachmentPullDoc, ThreadMessageAttachmentWebDoc, UserEventDAO, }; use crate::{ + auth::LdapEncryptionKind, integration::{Integration, IntegrationKind, IntegrationStatus}, interface::UserValue, notification::{Notification, NotificationRecipient}, repository::RepositoryKind, schema::{ - auth::{self, OAuthCredential, OAuthProvider}, + auth::{self, LdapCredential, OAuthCredential, OAuthProvider}, email::{AuthMethod, EmailSetting, Encryption}, job, repository::{ @@ -67,6 +68,26 @@ impl TryFrom for OAuthCredential { } } +impl TryFrom for LdapCredential { + type Error = anyhow::Error; + + fn try_from(val: LdapCredentialDAO) -> Result { + Ok(LdapCredential { + host: val.host, + port: val.port as i32, + bind_dn: val.bind_dn, + base_dn: val.base_dn, + user_filter: val.user_filter, + encryption: LdapEncryptionKind::from_enum_str(&val.encryption)?, + skip_tls_verify: val.skip_tls_verify, + email_attribute: val.email_attribute, + name_attribute: val.name_attribute, + created_at: val.created_at, + updated_at: val.updated_at, + }) + } +} + impl TryFrom for EmailSetting { type Error = anyhow::Error; @@ -447,6 +468,25 @@ impl DbEnum for OAuthProvider { } } +impl DbEnum for LdapEncryptionKind { + fn as_enum_str(&self) -> &'static str { + match self { + LdapEncryptionKind::None => "none", + LdapEncryptionKind::StartTLS => "starttls", + LdapEncryptionKind::LDAPS => "ldaps", + } + } + + fn from_enum_str(s: &str) -> anyhow::Result { + match s { + "none" => Ok(LdapEncryptionKind::None), + "starttls" => Ok(LdapEncryptionKind::StartTLS), + "ldaps" => Ok(LdapEncryptionKind::LDAPS), + _ => bail!("Invalid Ldap encryption kind"), + } + } +} + impl DbEnum for AuthMethod { fn as_enum_str(&self) -> &'static str { match self { diff --git a/ee/tabby-schema/src/schema/auth.rs b/ee/tabby-schema/src/schema/auth.rs index 404da5d72a8c..237e0d4538d4 100644 --- a/ee/tabby-schema/src/schema/auth.rs +++ b/ee/tabby-schema/src/schema/auth.rs @@ -67,6 +67,15 @@ pub struct TokenAuthInput { pub password: String, } +/// Input parameters for token_auth_ldap mutation +#[derive(Validate)] +pub struct TokenAuthLdapInput<'a> { + #[validate(length(min = 1, code = "user_id", message = "User ID should not be empty"))] + pub user_id: &'a str, + #[validate(length(min = 1, code = "password", message = "Password should not be empty"))] + pub password: &'a str, +} + /// Input parameters for register mutation /// `validate` attribute is used to validate the input parameters /// - `code` argument specifies which parameter causes the failure @@ -322,6 +331,35 @@ pub enum OAuthProvider { Gitlab, } +#[derive(GraphQLEnum, Clone, Serialize, Deserialize, PartialEq, Debug)] +pub enum AuthProviderKind { + OAuthGithub, + OAuthGoogle, + OAuthGitlab, + Ldap, +} + +impl From for AuthProvider { + fn from(provider: OAuthProvider) -> Self { + match provider { + OAuthProvider::Github => AuthProvider { + kind: AuthProviderKind::OAuthGithub, + }, + OAuthProvider::Google => AuthProvider { + kind: AuthProviderKind::OAuthGoogle, + }, + OAuthProvider::Gitlab => AuthProvider { + kind: AuthProviderKind::OAuthGitlab, + }, + } + } +} + +#[derive(GraphQLObject)] +pub struct AuthProvider { + pub kind: AuthProviderKind, +} + #[derive(GraphQLObject)] pub struct OAuthCredential { pub provider: OAuthProvider, @@ -348,6 +386,65 @@ pub struct UpdateOAuthCredentialInput { pub client_secret: Option, } +#[derive(GraphQLEnum, PartialEq, Debug)] +pub enum LdapEncryptionKind { + None, + StartTLS, + LDAPS, +} + +#[derive(GraphQLInputObject, Validate)] +pub struct UpdateLdapCredentialInput { + #[validate(length( + min = 1, + code = "host", + message = "host should not be empty and should be a valid hostname or IP address" + ))] + pub host: String, + pub port: i32, + + #[validate(length(min = 1, code = "bindDn", message = "bindDn cannot be empty"))] + pub bind_dn: String, + pub bind_password: Option, + + #[validate(length(min = 1, code = "baseDn", message = "baseDn cannot be empty"))] + pub base_dn: String, + #[validate(length( + min = 1, + code = "userFilter", + message = "userFilter cannot be empty, and should be in the format of `(uid=%s)`" + ))] + pub user_filter: String, + + pub encryption: LdapEncryptionKind, + pub skip_tls_verify: bool, + + #[validate(length( + min = 1, + code = "emailAttribute", + message = "emailAttribute cannot be empty" + ))] + pub email_attribute: String, + // if name_attribute is None, we will use username as name + pub name_attribute: Option, +} + +#[derive(GraphQLObject)] +pub struct LdapCredential { + pub host: String, + pub port: i32, + pub bind_dn: String, + pub base_dn: String, + pub user_filter: String, + pub encryption: LdapEncryptionKind, + pub skip_tls_verify: bool, + pub email_attribute: String, + pub name_attribute: Option, + + pub created_at: DateTime, + pub updated_at: DateTime, +} + #[async_trait] pub trait AuthenticationService: Send + Sync { async fn register( @@ -361,6 +458,8 @@ pub trait AuthenticationService: Send + Sync { async fn token_auth(&self, email: String, password: String) -> Result; + async fn token_auth_ldap(&self, email: &str, password: &str) -> Result; + async fn refresh_token(&self, refresh_token: String) -> Result; async fn verify_access_token(&self, access_token: &str) -> Result; async fn verify_auth_token(&self, token: &str) -> Result; @@ -414,8 +513,13 @@ pub trait AuthenticationService: Send + Sync { ) -> Result>; async fn update_oauth_credential(&self, input: UpdateOAuthCredentialInput) -> Result<()>; - async fn delete_oauth_credential(&self, provider: OAuthProvider) -> Result<()>; + + async fn read_ldap_credential(&self) -> Result>; + async fn test_ldap_connection(&self, input: UpdateLdapCredentialInput) -> Result<()>; + async fn update_ldap_credential(&self, input: UpdateLdapCredentialInput) -> Result<()>; + async fn delete_ldap_credential(&self) -> Result<()>; + async fn update_user_active(&self, id: &ID, active: bool) -> Result<()>; async fn update_user_role(&self, id: &ID, is_admin: bool) -> Result<()>; async fn update_user_avatar(&self, id: &ID, avatar: Option>) -> Result<()>; diff --git a/ee/tabby-schema/src/schema/mod.rs b/ee/tabby-schema/src/schema/mod.rs index c6c75e6e954d..550aa571be83 100644 --- a/ee/tabby-schema/src/schema/mod.rs +++ b/ee/tabby-schema/src/schema/mod.rs @@ -28,7 +28,8 @@ use async_openai::{ }, }; use auth::{ - AuthenticationService, Invitation, RefreshTokenResponse, RegisterResponse, TokenAuthResponse, + AuthProvider, AuthProviderKind, AuthenticationService, Invitation, LdapCredential, + RefreshTokenResponse, RegisterResponse, TokenAuthResponse, UpdateLdapCredentialInput, UserSecured, }; use base64::Engine; @@ -41,8 +42,10 @@ use juniper::{ graphql_object, graphql_subscription, graphql_value, FieldError, GraphQLEnum, GraphQLObject, IntoFieldError, Object, RootNode, ScalarValue, Value, ID, }; +use ldap3::result::LdapError; use notification::NotificationService; use repository::RepositoryGrepOutput; +use strum::IntoEnumIterator; use tabby_common::{ api::{code::CodeSearch, event::EventLogger}, config::CompletionConfig, @@ -145,6 +148,12 @@ pub enum CoreError { Other(#[from] anyhow::Error), } +impl From for CoreError { + fn from(err: LdapError) -> Self { + Self::Other(err.into()) + } +} + impl IntoFieldError for CoreError { fn into_field_error(self) -> FieldError { match self { @@ -429,6 +438,29 @@ impl Query { Ok(RepositoryGrepOutput { files, elapsed_ms }) } + async fn auth_providers(ctx: &Context) -> Result> { + let mut providers = vec![]; + + let auth = ctx.locator.auth(); + for x in OAuthProvider::iter() { + if auth + .read_oauth_credential(x.clone()) + .await + .is_ok_and(|x| x.is_some()) + { + providers.push(x.into()); + } + } + + if auth.read_ldap_credential().await.is_ok_and(|x| x.is_some()) { + providers.push(AuthProvider { + kind: AuthProviderKind::Ldap, + }); + } + + Ok(providers) + } + async fn oauth_credential( ctx: &Context, provider: OAuthProvider, @@ -442,6 +474,11 @@ impl Query { ctx.locator.auth().oauth_callback_url(provider).await } + async fn ldap_credential(ctx: &Context) -> Result> { + check_admin(ctx).await?; + ctx.locator.auth().read_ldap_credential().await + } + async fn server_info(ctx: &Context) -> Result { Ok(ServerInfo { is_admin_initialized: ctx.locator.auth().is_admin_initialized().await?, @@ -975,6 +1012,22 @@ impl Mutation { .await } + async fn token_auth_ldap( + ctx: &Context, + user_id: String, + password: String, + ) -> Result { + let input = auth::TokenAuthLdapInput { + user_id: &user_id, + password: &password, + }; + input.validate()?; + ctx.locator + .auth() + .token_auth_ldap(&user_id, &password) + .await + } + async fn verify_token(ctx: &Context, token: String) -> Result { ctx.locator.auth().verify_access_token(&token).await?; Ok(true) @@ -1058,6 +1111,31 @@ impl Mutation { Ok(true) } + async fn test_ldap_connection(ctx: &Context, input: UpdateLdapCredentialInput) -> Result { + check_admin(ctx).await?; + check_license(ctx, &[LicenseType::Enterprise]).await?; + ctx.locator.auth().test_ldap_connection(input).await?; + Ok(true) + } + + async fn update_ldap_credential( + ctx: &Context, + input: UpdateLdapCredentialInput, + ) -> Result { + check_admin(ctx).await?; + check_license(ctx, &[LicenseType::Enterprise]).await?; + input.validate()?; + + ctx.locator.auth().update_ldap_credential(input).await?; + Ok(true) + } + + async fn delete_ldap_credential(ctx: &Context) -> Result { + check_admin(ctx).await?; + ctx.locator.auth().delete_ldap_credential().await?; + Ok(true) + } + async fn update_email_setting(ctx: &Context, input: EmailSettingInput) -> Result { check_admin(ctx).await?; input.validate()?; diff --git a/ee/tabby-webserver/Cargo.toml b/ee/tabby-webserver/Cargo.toml index 75afa8fae5bf..9f30a82b2471 100644 --- a/ee/tabby-webserver/Cargo.toml +++ b/ee/tabby-webserver/Cargo.toml @@ -23,6 +23,7 @@ juniper.workspace = true juniper_axum = { version = "0.1", features = ["subscriptions"] } juniper_graphql_ws = "0.4" lazy_static.workspace = true +ldap3 = "0.11.0" lettre = { version = "0.11.3", features = ["tokio1", "tokio1-native-tls"] } mime_guess.workspace = true pin-project = "1.1.3" diff --git a/ee/tabby-webserver/src/ldap.rs b/ee/tabby-webserver/src/ldap.rs new file mode 100644 index 000000000000..9bd8eb1379f6 --- /dev/null +++ b/ee/tabby-webserver/src/ldap.rs @@ -0,0 +1,131 @@ +use anyhow::anyhow; +use async_trait::async_trait; +use ldap3::{drive, LdapConnAsync, LdapConnSettings, Scope, SearchEntry}; +use tabby_schema::{CoreError, Result}; + +#[async_trait] +pub trait LdapClient: Send + Sync { + async fn validate(&mut self, user: &str, password: &str) -> Result; +} + +pub fn new_ldap_client( + host: &str, + port: i64, + encryption: &str, + skip_verify_tls: bool, + bind_dn: String, + bind_password: &str, + base_dn: String, + user_filter: String, + email_attr: String, + name_attr: Option, +) -> impl LdapClient { + let mut settings = LdapConnSettings::new(); + if encryption == "starttls" { + settings = settings.set_starttls(true); + }; + if skip_verify_tls { + settings = settings.set_no_tls_verify(true); + }; + + let schema = if encryption == "ldaps" { + "ldaps" + } else { + "ldap" + }; + + LdapClientImpl { + address: format!("{}://{}:{}", schema, host, port), + bind_dn, + bind_password: bind_password.to_string(), + base_dn, + user_filter, + + email_attr, + name_attr, + + settings, + } +} + +pub struct LdapClientImpl { + address: String, + bind_dn: String, + bind_password: String, + base_dn: String, + user_filter: String, + + email_attr: String, + name_attr: Option, + + settings: LdapConnSettings, +} + +pub struct LdapUser { + pub email: String, + pub name: String, +} + +#[async_trait] +impl LdapClient for LdapClientImpl { + async fn validate(&mut self, user: &str, password: &str) -> Result { + let (connection, mut client) = + LdapConnAsync::with_settings(self.settings.clone(), &self.address).await?; + drive!(connection); + + // use bind_dn to search + let _res = client + .simple_bind(&self.bind_dn, &self.bind_password) + .await? + .success()?; + + let mut attrs = vec![&self.email_attr]; + if let Some(name_attr) = &self.name_attr { + attrs.push(name_attr); + } + let searched = client + .search( + &self.base_dn, + Scope::OneLevel, + &self.user_filter.replace("%s", user), + attrs, + ) + .await?; + + if let Some(entry) = searched.0.into_iter().next() { + let entry = SearchEntry::construct(entry); + let user_dn = entry.dn; + let email = entry + .attrs + .get(&self.email_attr) + .and_then(|v| v.first()) + .cloned() + .ok_or_else(|| CoreError::Other(anyhow!("email not found for user")))?; + let name = if let Some(name_attr) = &self.name_attr { + entry + .attrs + .get(name_attr) + .and_then(|v| v.first()) + .cloned() + .ok_or_else(|| CoreError::Other(anyhow!("name not found for user")))? + } else { + user.to_string() + }; + + client.simple_bind(&user_dn, password).await?.success()?; + + Ok(LdapUser { email, name }) + } else { + Err(ldap3::LdapError::LdapResult { + result: ldap3::LdapResult { + rc: 32, + matched: user.to_string(), + text: "User not found".to_string(), + refs: vec![], + ctrls: vec![], + }, + } + .into()) + } + } +} diff --git a/ee/tabby-webserver/src/lib.rs b/ee/tabby-webserver/src/lib.rs index 06560ebd4d16..17e692577d85 100644 --- a/ee/tabby-webserver/src/lib.rs +++ b/ee/tabby-webserver/src/lib.rs @@ -2,6 +2,7 @@ mod axum; mod hub; mod jwt; +mod ldap; mod oauth; mod path; mod rate_limit; diff --git a/ee/tabby-webserver/src/service/auth.rs b/ee/tabby-webserver/src/service/auth.rs index e8d13c17dacf..6fbd615640f8 100644 --- a/ee/tabby-webserver/src/service/auth.rs +++ b/ee/tabby-webserver/src/service/auth.rs @@ -12,9 +12,10 @@ use juniper::ID; use tabby_db::{DbConn, InvitationDAO}; use tabby_schema::{ auth::{ - AuthenticationService, Invitation, JWTPayload, OAuthCredential, OAuthError, OAuthProvider, - OAuthResponse, RefreshTokenResponse, RegisterResponse, RequestInvitationInput, - TokenAuthResponse, UpdateOAuthCredentialInput, UserSecured, + AuthenticationService, Invitation, JWTPayload, LdapCredential, OAuthCredential, OAuthError, + OAuthProvider, OAuthResponse, RefreshTokenResponse, RegisterResponse, + RequestInvitationInput, TokenAuthResponse, UpdateLdapCredentialInput, + UpdateOAuthCredentialInput, UserSecured, }, email::EmailService, is_demo_mode, @@ -29,6 +30,7 @@ use super::{graphql_pagination_to_filter, UserSecuredExt}; use crate::{ bail, jwt::{generate_jwt, validate_jwt}, + ldap::{self, LdapClient}, oauth::{self, OAuthClient}, }; @@ -320,6 +322,44 @@ impl AuthenticationService for AuthenticationServiceImpl { Ok(resp) } + async fn token_auth_ldap(&self, user_id: &str, password: &str) -> Result { + let license = self + .license + .read() + .await + .context("Failed to read license info")?; + + let credential = self.db.read_ldap_credential().await?; + if credential.is_none() { + bail!("LDAP is not configured"); + } + + let credential = credential.unwrap(); + let mut client = ldap::new_ldap_client( + credential.host.as_ref(), + credential.port, + credential.encryption.as_str(), + credential.skip_tls_verify, + credential.bind_dn, + &credential.bind_password, + credential.base_dn, + credential.user_filter, + credential.email_attribute, + credential.name_attribute, + ); + + ldap_login( + &mut client, + &self.db, + &*self.setting, + &license, + &*self.mail, + user_id, + password, + ) + .await + } + async fn refresh_token(&self, token: String) -> Result { let Some(refresh_token) = self.db.get_refresh_token(&token).await? else { bail!("Invalid refresh token"); @@ -551,6 +591,82 @@ impl AuthenticationService for AuthenticationServiceImpl { Ok(()) } + async fn read_ldap_credential(&self) -> Result> { + let credential = self.db.read_ldap_credential().await?; + match credential { + Some(c) => Ok(Some(c.try_into()?)), + None => Ok(None), + } + } + + async fn test_ldap_connection(&self, input: UpdateLdapCredentialInput) -> Result<()> { + let password = if let Some(password) = input.bind_password.as_deref() { + password + } else { + &self + .db + .read_ldap_credential() + .await? + .ok_or_else(|| anyhow!("LDAP password is not configured"))? + .bind_password + }; + let mut client = ldap::new_ldap_client( + input.host.as_ref(), + input.port as i64, + input.encryption.as_enum_str(), + input.skip_tls_verify, + input.bind_dn, + password, + input.base_dn, + input.user_filter, + input.email_attribute, + input.name_attribute, + ); + + if let Err(e) = client.validate("", "").await { + if e.to_string().contains("User not found") { + return Ok(()); + } else { + bail!("Failed to connect to LDAP server: {e}"); + } + } + + Ok(()) + } + + async fn update_ldap_credential(&self, input: UpdateLdapCredentialInput) -> Result<()> { + let password = if let Some(password) = input.bind_password.as_deref() { + password + } else { + &self + .db + .read_ldap_credential() + .await? + .ok_or_else(|| anyhow!("LDAP password is not configured"))? + .bind_password + }; + self.db + .update_ldap_credential( + &input.host, + input.port, + &input.bind_dn, + password, + &input.base_dn, + &input.user_filter, + input.encryption.as_enum_str(), + input.skip_tls_verify, + &input.email_attribute, + input.name_attribute.as_deref(), + ) + .await?; + Ok(()) + } + + async fn delete_ldap_credential(&self) -> Result<()> { + self.db.delete_ldap_credential().await?; + Ok(()) + } + async fn update_user_active(&self, id: &ID, active: bool) -> Result<()> { let id = id.as_rowid()?; let user = self.db.get_user(id).await?.context("User doesn't exits")?; @@ -580,6 +696,28 @@ impl AuthenticationService for AuthenticationServiceImpl { } } +async fn ldap_login( + client: &mut dyn LdapClient, + db: &DbConn, + setting: &dyn SettingService, + license: &LicenseInfo, + mail: &dyn EmailService, + user_id: &str, + password: &str, +) -> Result { + let user = client.validate(user_id, password).await?; + let user_id = get_or_create_sso_user(license, db, setting, mail, &user.email, &user.name) + .await + .map_err(|e| CoreError::Other(anyhow!("fail to get or create ldap user: {}", e)))?; + + let refresh_token = db.create_refresh_token(user_id).await?; + let access_token = generate_jwt(user_id.as_id()) + .map_err(|e| CoreError::Other(anyhow!("fail to create access_token: {}", e)))?; + + let resp = TokenAuthResponse::new(access_token, refresh_token); + Ok(resp) +} + async fn oauth_login( client: Arc, code: String, @@ -591,7 +729,7 @@ async fn oauth_login( let access_token = client.exchange_code_for_token(code).await?; let email = client.fetch_user_email(&access_token).await?; let name = client.fetch_user_full_name(&access_token).await?; - let user_id = get_or_create_oauth_user(license, db, setting, mail, &email, &name).await?; + let user_id = get_or_create_sso_user(license, db, setting, mail, &email, &name).await?; let refresh_token = db.create_refresh_token(user_id).await?; @@ -604,7 +742,7 @@ async fn oauth_login( Ok(resp) } -async fn get_or_create_oauth_user( +async fn get_or_create_sso_user( license: &LicenseInfo, db: &DbConn, setting: &dyn SettingService, @@ -638,7 +776,7 @@ async fn get_or_create_oauth_user( // it's ok to set password to null here, because // 1. both `register` & `token_auth` mutation will do input validation, so empty password won't be accepted // 2. `password_verify` will always return false for empty password hash read from user table - // so user created here is only able to login by github oauth, normal login won't work + // so user created here is only able to login by github oauth, or ldap, normal login won't work let res = db.create_user(email.to_owned(), None, false, name).await?; if let Err(e) = mail.send_signup(email.to_string()).await { @@ -704,6 +842,9 @@ fn password_verify(raw: &str, hash: &str) -> bool { #[cfg(test)] mod tests { + use tabby_schema::auth::LdapEncryptionKind; + + use crate::service::auth::testutils::FakeLdapClient; struct MockLicenseService { status: LicenseStatus, @@ -1039,7 +1180,7 @@ mod tests { service.db.update_user_active(id, false).await.unwrap(); let setting = service.setting; - let res = get_or_create_oauth_user( + let res = get_or_create_sso_user( &license, &service.db, &*setting, @@ -1056,7 +1197,7 @@ mod tests { .await .unwrap(); - let res = get_or_create_oauth_user( + let res = get_or_create_sso_user( &license, &service.db, &*setting, @@ -1074,7 +1215,7 @@ mod tests { tokio::time::sleep(Duration::milliseconds(50).to_std().unwrap()).await; assert_eq!(mail.list_mail().await[0].subject, "Welcome to Tabby!"); - let res = get_or_create_oauth_user( + let res = get_or_create_sso_user( &license, &service.db, &*setting, @@ -1091,7 +1232,7 @@ mod tests { .await .unwrap(); - let res = get_or_create_oauth_user( + let res = get_or_create_sso_user( &license, &service.db, &*setting, @@ -1540,6 +1681,67 @@ mod tests { assert!(service.refresh_token(token.refresh_token).await.is_err()); } + #[tokio::test] + async fn test_ldap_credential() { + let service = test_authentication_service().await; + service + .update_ldap_credential(UpdateLdapCredentialInput { + host: "ldap.example.com".into(), + port: 389, + bind_dn: "cn=admin,dc=example,dc=com".into(), + bind_password: Some("password".into()), + base_dn: "dc=example,dc=com".into(), + user_filter: "(&(objectClass=person)(uid=%s))".into(), + encryption: LdapEncryptionKind::None, + skip_tls_verify: false, + email_attribute: "mail".into(), + name_attribute: Some("cn".into()), + }) + .await + .unwrap(); + + // test the read_ldap_credential + let cred = service.read_ldap_credential().await.unwrap().unwrap(); + assert_eq!(cred.host, "ldap.example.com"); + assert_eq!(cred.port, 389); + assert_eq!(cred.bind_dn, "cn=admin,dc=example,dc=com"); + assert_eq!(cred.base_dn, "dc=example,dc=com"); + assert_eq!(cred.user_filter, "(&(objectClass=person)(uid=%s))"); + assert_eq!(cred.encryption, LdapEncryptionKind::None); + assert!(!cred.skip_tls_verify); + assert_eq!(cred.email_attribute, "mail"); + assert_eq!(cred.name_attribute, Some("cn".into())); + + service + .update_ldap_credential(UpdateLdapCredentialInput { + host: "ldap1.example1.com".into(), + port: 3890, + bind_dn: "cn=admin1,dc=example1,dc=com".into(), + bind_password: None, + base_dn: "dc=example1,dc=com".into(), + user_filter: "((uid=%s))".into(), + encryption: LdapEncryptionKind::None, + skip_tls_verify: true, + email_attribute: "email".into(), + name_attribute: Some("name".into()), + }) + .await + .unwrap(); + + // use db to verify the update and password sine it's not returned in service + let cred = service.db.read_ldap_credential().await.unwrap().unwrap(); + assert_eq!(cred.host, "ldap1.example1.com"); + assert_eq!(cred.port, 3890); + assert_eq!(cred.bind_dn, "cn=admin1,dc=example1,dc=com"); + assert_eq!(cred.bind_password, "password"); + assert_eq!(cred.base_dn, "dc=example1,dc=com"); + assert_eq!(cred.user_filter, "((uid=%s))"); + assert_eq!(cred.encryption, "none"); + assert!(cred.skip_tls_verify); + assert_eq!(cred.email_attribute, "email"); + assert_eq!(cred.name_attribute, Some("name".into())); + } + #[tokio::test] async fn test_oauth_credential() { let service = test_authentication_service().await; @@ -1562,6 +1764,71 @@ mod tests { assert_eq!(cred.client_secret, "secret"); } + #[tokio::test] + async fn test_ldap_login() { + let service = test_authentication_service().await; + let license = LicenseInfo { + r#type: LicenseType::Enterprise, + status: LicenseStatus::Ok, + seats: 1000, + seats_used: 0, + issued_at: None, + expires_at: None, + }; + + service + .create_invitation("user@example.com".into()) + .await + .unwrap(); + let mut ldap_client = FakeLdapClient { state: "" }; + + let response = ldap_login( + &mut ldap_client, + &service.db, + &*service.setting, + &license, + &*service.mail, + "user", + "password", + ) + .await + .unwrap(); + + assert!(!response.refresh_token.is_empty()); + } + + #[tokio::test] + async fn test_ldap_login_not_found() { + let service = test_authentication_service().await; + let license = LicenseInfo { + r#type: LicenseType::Enterprise, + status: LicenseStatus::Ok, + seats: 1000, + seats_used: 0, + issued_at: None, + expires_at: None, + }; + + service + .create_invitation("user@example.com".into()) + .await + .unwrap(); + let mut ldap_client = FakeLdapClient { state: "not_found" }; + + let response = ldap_login( + &mut ldap_client, + &service.db, + &*service.setting, + &license, + &*service.mail, + "user", + "password", + ) + .await; + + assert!(response.is_err()); + } + #[tokio::test] async fn test_oauth_login() { let service = test_authentication_service().await; diff --git a/ee/tabby-webserver/src/service/auth/testutils.rs b/ee/tabby-webserver/src/service/auth/testutils.rs index eca153cee36a..a44c77dc2765 100644 --- a/ee/tabby-webserver/src/service/auth/testutils.rs +++ b/ee/tabby-webserver/src/service/auth/testutils.rs @@ -3,14 +3,43 @@ use chrono::{Duration, Utc}; use juniper::ID; use tabby_schema::{ auth::{ - AuthenticationService, Invitation, JWTPayload, OAuthCredential, OAuthError, OAuthProvider, - OAuthResponse, RefreshTokenResponse, RegisterResponse, RequestInvitationInput, - TokenAuthResponse, UpdateOAuthCredentialInput, UserSecured, + AuthenticationService, Invitation, JWTPayload, LdapCredential, OAuthCredential, OAuthError, + OAuthProvider, OAuthResponse, RefreshTokenResponse, RegisterResponse, + RequestInvitationInput, TokenAuthResponse, UpdateLdapCredentialInput, + UpdateOAuthCredentialInput, UserSecured, }, Result, }; use tokio::task::JoinHandle; +use crate::ldap::{LdapClient, LdapUser}; + +pub struct FakeLdapClient<'a> { + pub state: &'a str, +} + +#[async_trait] +impl<'a> LdapClient for FakeLdapClient<'a> { + async fn validate(&mut self, user_id: &str, _password: &str) -> Result { + match self.state { + "not_found" => Err(ldap3::LdapError::LdapResult { + result: ldap3::LdapResult { + rc: 32, + matched: user_id.to_string(), + text: "User not found".to_string(), + refs: vec![], + ctrls: vec![], + }, + } + .into()), + _ => Ok(LdapUser { + email: "user@example.com".to_string(), + name: "Test User".to_string(), + }), + } + } +} + pub struct FakeAuthService { users: Vec, } @@ -80,6 +109,13 @@ impl AuthenticationService for FakeAuthService { )) } + async fn token_auth_ldap(&self, _user_id: &str, _password: &str) -> Result { + Ok(TokenAuthResponse::new( + "access_token".to_string(), + "refresh_token".to_string(), + )) + } + async fn refresh_token(&self, _token: String) -> Result { Ok(RefreshTokenResponse::new( "access_token".to_string(), @@ -173,6 +209,22 @@ impl AuthenticationService for FakeAuthService { Ok(vec![]) } + async fn read_ldap_credential(&self) -> Result> { + Ok(None) + } + + async fn test_ldap_connection(&self, _credential: UpdateLdapCredentialInput) -> Result<()> { + Ok(()) + } + + async fn update_ldap_credential(&self, _input: UpdateLdapCredentialInput) -> Result<()> { + Ok(()) + } + + async fn delete_ldap_credential(&self) -> Result<()> { + Ok(()) + } + async fn oauth( &self, _code: String, diff --git a/ee/tabby-webserver/src/service/mod.rs b/ee/tabby-webserver/src/service/mod.rs index f67734392f57..7e7cc2731e03 100644 --- a/ee/tabby-webserver/src/service/mod.rs +++ b/ee/tabby-webserver/src/service/mod.rs @@ -24,6 +24,8 @@ use answer::AnswerService; use anyhow::Context; use async_trait::async_trait; pub use auth::create as new_auth_service; +#[cfg(test)] +pub use auth::testutils::FakeAuthService; use axum::{ body::Body, http::{HeaderName, HeaderValue, Request, StatusCode},