Skip to content

Commit

Permalink
ksmbd: fix racy issue while destroying session on multichannel
Browse files Browse the repository at this point in the history
After multi-channel connection with windows, Several channels of
session are connected. Among them, if there is a problem in one channel,
Windows connects again after disconnecting the channel. In this process,
the session is released and a kernel oop can occurs while processing
requests to other channels. When the channel is disconnected, if other
channels still exist in the session after deleting the channel from
the channel list in the session, the session should not be released.
Finally, the session will be released after all channels are disconnected.

Signed-off-by: Namjae Jeon <[email protected]>
  • Loading branch information
namjaejeon committed Jul 25, 2022
1 parent ca16111 commit 7202068
Show file tree
Hide file tree
Showing 13 changed files with 114 additions and 93 deletions.
56 changes: 31 additions & 25 deletions auth.c
Original file line number Diff line number Diff line change
Expand Up @@ -241,8 +241,8 @@ static int ksmbd_gen_sess_key(struct ksmbd_session *sess, char *hash,
return rc;
}

static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash,
char *dname)
static int calc_ntlmv2_hash(struct ksmbd_conn *conn, struct ksmbd_session *sess,
char *ntlmv2_hash, char *dname)
{
int ret, len, conv_len;
wchar_t *domain = NULL;
Expand Down Expand Up @@ -278,7 +278,7 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash,
}

conv_len = smb_strtoUTF16(uniname, user_name(sess->user), len,
sess->conn->local_nls);
conn->local_nls);
if (conv_len < 0 || conv_len > len) {
ret = -EINVAL;
goto out;
Expand All @@ -302,7 +302,7 @@ static int calc_ntlmv2_hash(struct ksmbd_session *sess, char *ntlmv2_hash,
}

conv_len = smb_strtoUTF16((__le16 *)domain, dname, len,
sess->conn->local_nls);
conn->local_nls);
if (conv_len < 0 || conv_len > len) {
ret = -EINVAL;
goto out;
Expand Down Expand Up @@ -372,8 +372,9 @@ int ksmbd_auth_ntlm(struct ksmbd_session *sess, char *pw_buf, char *cryptkey)
*
* Return: 0 on success, error number on error
*/
int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2,
int blen, char *domain_name, char *cryptkey)
int ksmbd_auth_ntlmv2(struct ksmbd_conn *conn, struct ksmbd_session *sess,
struct ntlmv2_resp *ntlmv2, int blen, char *domain_name,
char *cryptkey)
{
char ntlmv2_hash[CIFS_ENCPWD_SIZE];
char ntlmv2_rsp[CIFS_HMAC_MD5_HASH_SIZE];
Expand All @@ -387,7 +388,7 @@ int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2,
return -ENOMEM;
}

rc = calc_ntlmv2_hash(sess, ntlmv2_hash, domain_name);
rc = calc_ntlmv2_hash(conn, sess, ntlmv2_hash, domain_name);
if (rc) {
ksmbd_debug(AUTH, "could not get v2 hash rc %d\n", rc);
goto out;
Expand Down Expand Up @@ -614,7 +615,8 @@ int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
/* process NTLMv2 authentication */
ksmbd_debug(AUTH, "decode_ntlmssp_authenticate_blob dname%s\n",
domain_name);
ret = ksmbd_auth_ntlmv2(sess, (struct ntlmv2_resp *)((char *)authblob + nt_off),
ret = ksmbd_auth_ntlmv2(conn, sess,
(struct ntlmv2_resp *)((char *)authblob + nt_off),
nt_len - CIFS_ENCPWD_SIZE,
domain_name, conn->ntlmssp.cryptkey);
kfree(domain_name);
Expand Down Expand Up @@ -998,8 +1000,9 @@ struct derivation {
bool binding;
};

static int generate_key(struct ksmbd_session *sess, struct kvec label,
struct kvec context, __u8 *key, unsigned int key_size)
static int generate_key(struct ksmbd_conn *conn, struct ksmbd_session *sess,
struct kvec label, struct kvec context, __u8 *key,
unsigned int key_size)
{
unsigned char zero = 0x0;
__u8 i[4] = {0, 0, 0, 1};
Expand Down Expand Up @@ -1059,8 +1062,8 @@ static int generate_key(struct ksmbd_session *sess, struct kvec label,
goto smb3signkey_ret;
}

if (sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
if (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM)
rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), L256, 4);
else
rc = crypto_shash_update(CRYPTO_HMACSHA256(ctx), L128, 4);
Expand Down Expand Up @@ -1095,17 +1098,17 @@ static int generate_smb3signingkey(struct ksmbd_session *sess,
if (!chann)
return 0;

if (sess->conn->dialect >= SMB30_PROT_ID && signing->binding)
if (conn->dialect >= SMB30_PROT_ID && signing->binding)
key = chann->smb3signingkey;
else
key = sess->smb3signingkey;

rc = generate_key(sess, signing->label, signing->context, key,
rc = generate_key(conn, sess, signing->label, signing->context, key,
SMB3_SIGN_KEY_SIZE);
if (rc)
return rc;

if (!(sess->conn->dialect >= SMB30_PROT_ID && signing->binding))
if (!(conn->dialect >= SMB30_PROT_ID && signing->binding))
memcpy(chann->smb3signingkey, key, SMB3_SIGN_KEY_SIZE);

ksmbd_debug(AUTH, "dumping generated AES signing keys\n");
Expand Down Expand Up @@ -1159,30 +1162,31 @@ struct derivation_twin {
struct derivation decryption;
};

static int generate_smb3encryptionkey(struct ksmbd_session *sess,
static int generate_smb3encryptionkey(struct ksmbd_conn *conn,
struct ksmbd_session *sess,
const struct derivation_twin *ptwin)
{
int rc;

rc = generate_key(sess, ptwin->encryption.label,
rc = generate_key(conn, sess, ptwin->encryption.label,
ptwin->encryption.context, sess->smb3encryptionkey,
SMB3_ENC_DEC_KEY_SIZE);
if (rc)
return rc;

rc = generate_key(sess, ptwin->decryption.label,
rc = generate_key(conn, sess, ptwin->decryption.label,
ptwin->decryption.context,
sess->smb3decryptionkey, SMB3_ENC_DEC_KEY_SIZE);
if (rc)
return rc;

ksmbd_debug(AUTH, "dumping generated AES encryption keys\n");
ksmbd_debug(AUTH, "Cipher type %d\n", sess->conn->cipher_type);
ksmbd_debug(AUTH, "Cipher type %d\n", conn->cipher_type);
ksmbd_debug(AUTH, "Session Id %llu\n", sess->id);
ksmbd_debug(AUTH, "Session Key %*ph\n",
SMB2_NTLMV2_SESSKEY_SIZE, sess->sess_key);
if (sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
sess->conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) {
if (conn->cipher_type == SMB2_ENCRYPTION_AES256_CCM ||
conn->cipher_type == SMB2_ENCRYPTION_AES256_GCM) {
ksmbd_debug(AUTH, "ServerIn Key %*ph\n",
SMB3_GCM256_CRYPTKEY_SIZE, sess->smb3encryptionkey);
ksmbd_debug(AUTH, "ServerOut Key %*ph\n",
Expand All @@ -1196,7 +1200,8 @@ static int generate_smb3encryptionkey(struct ksmbd_session *sess,
return 0;
}

int ksmbd_gen_smb30_encryptionkey(struct ksmbd_session *sess)
int ksmbd_gen_smb30_encryptionkey(struct ksmbd_conn *conn,
struct ksmbd_session *sess)
{
struct derivation_twin twin;
struct derivation *d;
Expand All @@ -1213,10 +1218,11 @@ int ksmbd_gen_smb30_encryptionkey(struct ksmbd_session *sess)
d->context.iov_base = "ServerIn ";
d->context.iov_len = 10;

return generate_smb3encryptionkey(sess, &twin);
return generate_smb3encryptionkey(conn, sess, &twin);
}

int ksmbd_gen_smb311_encryptionkey(struct ksmbd_session *sess)
int ksmbd_gen_smb311_encryptionkey(struct ksmbd_conn *conn,
struct ksmbd_session *sess)
{
struct derivation_twin twin;
struct derivation *d;
Expand All @@ -1233,7 +1239,7 @@ int ksmbd_gen_smb311_encryptionkey(struct ksmbd_session *sess)
d->context.iov_base = sess->Preauth_HashValue;
d->context.iov_len = 64;

return generate_smb3encryptionkey(sess, &twin);
return generate_smb3encryptionkey(conn, sess, &twin);
}

int ksmbd_gen_preauth_integrity_hash(struct ksmbd_conn *conn, char *buf,
Expand Down
11 changes: 7 additions & 4 deletions auth.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ int ksmbd_crypt_message(struct ksmbd_conn *conn, struct kvec *iov,
unsigned int nvec, int enc);
void ksmbd_copy_gss_neg_header(void *buf);
int ksmbd_auth_ntlm(struct ksmbd_session *sess, char *pw_buf, char *cryptkey);
int ksmbd_auth_ntlmv2(struct ksmbd_session *sess, struct ntlmv2_resp *ntlmv2,
int blen, char *domain_name, char *cryptkey);
int ksmbd_auth_ntlmv2(struct ksmbd_conn *conn, struct ksmbd_session *sess,
struct ntlmv2_resp *ntlmv2, int blen, char *domain_name,
char *cryptkey);
int ksmbd_decode_ntlmssp_auth_blob(struct authenticate_message *authblob,
int blob_len, struct ksmbd_conn *conn,
struct ksmbd_session *sess);
Expand All @@ -63,8 +64,10 @@ int ksmbd_gen_smb30_signingkey(struct ksmbd_session *sess,
struct ksmbd_conn *conn);
int ksmbd_gen_smb311_signingkey(struct ksmbd_session *sess,
struct ksmbd_conn *conn);
int ksmbd_gen_smb30_encryptionkey(struct ksmbd_session *sess);
int ksmbd_gen_smb311_encryptionkey(struct ksmbd_session *sess);
int ksmbd_gen_smb30_encryptionkey(struct ksmbd_conn *conn,
struct ksmbd_session *sess);
int ksmbd_gen_smb311_encryptionkey(struct ksmbd_conn *conn,
struct ksmbd_session *sess);
int ksmbd_gen_preauth_integrity_hash(struct ksmbd_conn *conn, char *buf,
__u8 *pi_hash);
int ksmbd_gen_sd_hash(struct ksmbd_conn *conn, char *sd_buf, int len,
Expand Down
7 changes: 0 additions & 7 deletions connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,6 @@

#define KSMBD_SOCKET_BACKLOG 16

/*
* WARNING
*
* This is nothing but a HACK. Session status should move to channel
* or to session. As of now we have 1 tcp_conn : 1 ksmbd_session, but
* we need to change it to 1 tcp_conn : N ksmbd_sessions.
*/
enum {
KSMBD_SESS_NEW = 0,
KSMBD_SESS_GOOD,
Expand Down
5 changes: 3 additions & 2 deletions mgmt/tree_connect.c
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,8 @@
#include "user_session.h"

struct ksmbd_tree_conn_status
ksmbd_tree_conn_connect(struct ksmbd_session *sess, char *share_name)
ksmbd_tree_conn_connect(struct ksmbd_conn *conn, struct ksmbd_session *sess,
char *share_name)
{
struct ksmbd_tree_conn_status status = {-EINVAL, NULL};
struct ksmbd_tree_connect_response *resp = NULL;
Expand Down Expand Up @@ -45,7 +46,7 @@ ksmbd_tree_conn_connect(struct ksmbd_session *sess, char *share_name)
goto out_error;
}

peer_addr = KSMBD_TCP_PEER_SOCKADDR(sess->conn);
peer_addr = KSMBD_TCP_PEER_SOCKADDR(conn);
resp = ksmbd_ipc_tree_connect_request(sess,
sc,
tree_conn,
Expand Down
4 changes: 3 additions & 1 deletion mgmt/tree_connect.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

struct ksmbd_share_config;
struct ksmbd_user;
struct ksmbd_conn;

struct ksmbd_tree_connect {
int id;
Expand Down Expand Up @@ -40,7 +41,8 @@ static inline int test_tree_conn_flag(struct ksmbd_tree_connect *tree_conn,
struct ksmbd_session;

struct ksmbd_tree_conn_status
ksmbd_tree_conn_connect(struct ksmbd_session *sess, char *share_name);
ksmbd_tree_conn_connect(struct ksmbd_conn *conn, struct ksmbd_session *sess,
char *share_name);

int ksmbd_tree_conn_disconnect(struct ksmbd_session *sess,
struct ksmbd_tree_connect *tree_conn);
Expand Down
69 changes: 46 additions & 23 deletions mgmt/user_session.c
Original file line number Diff line number Diff line change
Expand Up @@ -151,11 +151,8 @@ void ksmbd_session_destroy(struct ksmbd_session *sess)
if (!sess)
return;

if (!atomic_dec_and_test(&sess->refcnt))
return;

#ifdef CONFIG_SMB_INSECURE_SERVER
if (IS_SMB2(sess->conn)) {
if (hash_hashed(&sess->hlist)) {
down_write(&sessions_table_lock);
hash_del(&sess->hlist);
up_write(&sessions_table_lock);
Expand Down Expand Up @@ -192,16 +189,58 @@ static struct ksmbd_session *__session_lookup(unsigned long long id)
int ksmbd_session_register(struct ksmbd_conn *conn,
struct ksmbd_session *sess)
{
sess->conn = conn;
sess->dialect = conn->dialect;
memcpy(sess->ClientGUID, conn->ClientGUID, SMB2_CLIENT_GUID_SIZE);
return xa_err(xa_store(&conn->sessions, sess->id, sess, GFP_KERNEL));
}

static int ksmbd_chann_del(struct ksmbd_conn *conn, struct ksmbd_session *sess)
{
struct channel *chann, *tmp;

write_lock(&sess->chann_lock);
list_for_each_entry_safe(chann, tmp, &sess->ksmbd_chann_list,
chann_list) {
if (chann->conn == conn) {
list_del(&chann->chann_list);
kfree(chann);
write_unlock(&sess->chann_lock);
return 0;
}
}
write_unlock(&sess->chann_lock);

return -ENOENT;
}

void ksmbd_sessions_deregister(struct ksmbd_conn *conn)
{
struct ksmbd_session *sess;
unsigned long id;

xa_for_each(&conn->sessions, id, sess) {
if (conn->binding) {
int bkt;

down_write(&sessions_table_lock);
hash_for_each(sessions_table, bkt, sess, hlist) {
if (!ksmbd_chann_del(conn, sess)) {
up_write(&sessions_table_lock);
goto sess_destroy;
}
}
up_write(&sessions_table_lock);
} else {
unsigned long id;

xa_for_each(&conn->sessions, id, sess) {
if (!ksmbd_chann_del(conn, sess))
goto sess_destroy;
}
}

return;

sess_destroy:
if (list_empty(&sess->ksmbd_chann_list)) {
xa_erase(&conn->sessions, sess->id);
ksmbd_session_destroy(sess);
}
Expand All @@ -213,27 +252,12 @@ struct ksmbd_session *ksmbd_session_lookup(struct ksmbd_conn *conn,
return xa_load(&conn->sessions, id);
}

int get_session(struct ksmbd_session *sess)
{
return atomic_inc_not_zero(&sess->refcnt);
}

void put_session(struct ksmbd_session *sess)
{
if (atomic_dec_and_test(&sess->refcnt))
pr_err("get/%s seems to be mismatched.", __func__);
}

struct ksmbd_session *ksmbd_session_lookup_slowpath(unsigned long long id)
{
struct ksmbd_session *sess;

down_read(&sessions_table_lock);
sess = __session_lookup(id);
if (sess) {
if (!get_session(sess))
sess = NULL;
}
up_read(&sessions_table_lock);

return sess;
Expand Down Expand Up @@ -326,7 +350,6 @@ static struct ksmbd_session *__session_create(int protocol)
INIT_LIST_HEAD(&sess->ksmbd_chann_list);
INIT_LIST_HEAD(&sess->rpc_handle_list);
sess->sequence_number = 1;
atomic_set(&sess->refcnt, 1);
rwlock_init(&sess->chann_lock);

switch (protocol) {
Expand Down
7 changes: 3 additions & 4 deletions mgmt/user_session.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ struct preauth_session {
struct ksmbd_session {
u64 id;

__u16 dialect;
char ClientGUID[SMB2_CLIENT_GUID_SIZE];

struct ksmbd_user *user;
struct ksmbd_conn *conn;
unsigned int sequence_number;
unsigned int flags;

Expand All @@ -62,7 +64,6 @@ struct ksmbd_session {
__u8 smb3signingkey[SMB3_SIGN_KEY_SIZE];

struct ksmbd_file_table file_table;
atomic_t refcnt;
};

static inline int test_session_flag(struct ksmbd_session *sess, int bit)
Expand Down Expand Up @@ -106,6 +107,4 @@ void ksmbd_release_tree_conn_id(struct ksmbd_session *sess, int id);
int ksmbd_session_rpc_open(struct ksmbd_session *sess, char *rpc_name);
void ksmbd_session_rpc_close(struct ksmbd_session *sess, int id);
int ksmbd_session_rpc_method(struct ksmbd_session *sess, int id);
int get_session(struct ksmbd_session *sess);
void put_session(struct ksmbd_session *sess);
#endif /* __USER_SESSION_MANAGEMENT_H__ */
Loading

0 comments on commit 7202068

Please sign in to comment.