tls: Split decrypt_skb to two functions
authorBoris Pismenny <borisp@mellanox.com>
Fri, 13 Jul 2018 11:33:40 +0000 (14:33 +0300)
committerDavid S. Miller <davem@davemloft.net>
Mon, 16 Jul 2018 07:13:10 +0000 (00:13 -0700)
Previously, decrypt_skb also updated the TLS context.
Now, decrypt_skb only decrypts the payload using the current context,
while decrypt_skb_update also updates the state.

Later, in the tls_device Rx flow, we will use decrypt_skb directly.

Signed-off-by: Boris Pismenny <borisp@mellanox.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/tls.h
net/tls/tls_sw.c

index 5dcd808236a74629c6d5cb5494953e25b04df534..49b89221db43a4cd424cac40adbd426d97c301fd 100644 (file)
@@ -390,6 +390,8 @@ int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
                      unsigned char *record_type);
 void tls_register_device(struct tls_device *device);
 void tls_unregister_device(struct tls_device *device);
+int decrypt_skb(struct sock *sk, struct sk_buff *skb,
+               struct scatterlist *sgout);
 
 struct sk_buff *tls_validate_xmit_skb(struct sock *sk,
                                      struct net_device *dev,
index 7453f5ae081937ae84a5b3c8ae03120e04cb4d55..1d2271736717afb26646b6673be7e2081083fbf9 100644 (file)
@@ -53,7 +53,6 @@ static int tls_do_decryption(struct sock *sk,
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
-       struct strp_msg *rxm = strp_msg(skb);
        struct aead_request *aead_req;
 
        int ret;
@@ -71,18 +70,6 @@ static int tls_do_decryption(struct sock *sk,
 
        ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
 
-       if (ret < 0)
-               goto out;
-
-       rxm->offset += tls_ctx->rx.prepend_size;
-       rxm->full_len -= tls_ctx->rx.overhead_size;
-       tls_advance_record_sn(sk, &tls_ctx->rx);
-
-       ctx->decrypted = true;
-
-       ctx->saved_data_ready(sk);
-
-out:
        aead_request_free(aead_req);
        return ret;
 }
@@ -666,8 +653,29 @@ static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
        return skb;
 }
 
-static int decrypt_skb(struct sock *sk, struct sk_buff *skb,
-                      struct scatterlist *sgout)
+static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
+                             struct scatterlist *sgout)
+{
+       struct tls_context *tls_ctx = tls_get_ctx(sk);
+       struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
+       struct strp_msg *rxm = strp_msg(skb);
+       int err = 0;
+
+       err = decrypt_skb(sk, skb, sgout);
+       if (err < 0)
+               return err;
+
+       rxm->offset += tls_ctx->rx.prepend_size;
+       rxm->full_len -= tls_ctx->rx.overhead_size;
+       tls_advance_record_sn(sk, &tls_ctx->rx);
+       ctx->decrypted = true;
+       ctx->saved_data_ready(sk);
+
+       return err;
+}
+
+int decrypt_skb(struct sock *sk, struct sk_buff *skb,
+               struct scatterlist *sgout)
 {
        struct tls_context *tls_ctx = tls_get_ctx(sk);
        struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
@@ -812,7 +820,7 @@ int tls_sw_recvmsg(struct sock *sk,
                                if (err < 0)
                                        goto fallback_to_reg_recv;
 
-                               err = decrypt_skb(sk, skb, sgin);
+                               err = decrypt_skb_update(sk, skb, sgin);
                                for (; pages > 0; pages--)
                                        put_page(sg_page(&sgin[pages]));
                                if (err < 0) {
@@ -821,7 +829,7 @@ int tls_sw_recvmsg(struct sock *sk,
                                }
                        } else {
 fallback_to_reg_recv:
-                               err = decrypt_skb(sk, skb, NULL);
+                               err = decrypt_skb_update(sk, skb, NULL);
                                if (err < 0) {
                                        tls_err_abort(sk, EBADMSG);
                                        goto recv_end;
@@ -892,7 +900,7 @@ ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
        }
 
        if (!ctx->decrypted) {
-               err = decrypt_skb(sk, skb, NULL);
+               err = decrypt_skb_update(sk, skb, NULL);
 
                if (err < 0) {
                        tls_err_abort(sk, EBADMSG);