bpf: Add a bpf_sock pointer to __sk_buff and a bpf_sk_fullsock helper
authorMartin KaFai Lau <kafai@fb.com>
Sun, 10 Feb 2019 07:22:20 +0000 (23:22 -0800)
committerAlexei Starovoitov <ast@kernel.org>
Mon, 11 Feb 2019 03:46:17 +0000 (19:46 -0800)
In kernel, it is common to check "skb->sk && sk_fullsock(skb->sk)"
before accessing the fields in sock.  For example, in __netdev_pick_tx:

static u16 __netdev_pick_tx(struct net_device *dev, struct sk_buff *skb,
    struct net_device *sb_dev)
{
/* ... */

struct sock *sk = skb->sk;

if (queue_index != new_index && sk &&
    sk_fullsock(sk) &&
    rcu_access_pointer(sk->sk_dst_cache))
sk_tx_queue_set(sk, new_index);

/* ... */

return queue_index;
}

This patch adds a "struct bpf_sock *sk" pointer to the "struct __sk_buff"
where a few of the convert_ctx_access() in filter.c has already been
accessing the skb->sk sock_common's fields,
e.g. sock_ops_convert_ctx_access().

"__sk_buff->sk" is a PTR_TO_SOCK_COMMON_OR_NULL in the verifier.
Some of the fileds in "bpf_sock" will not be directly
accessible through the "__sk_buff->sk" pointer.  It is limited
by the new "bpf_sock_common_is_valid_access()".
e.g. The existing "type", "protocol", "mark" and "priority" in bpf_sock
     are not allowed.

The newly added "struct bpf_sock *bpf_sk_fullsock(struct bpf_sock *sk)"
can be used to get a sk with all accessible fields in "bpf_sock".
This helper is added to both cg_skb and sched_(cls|act).

int cg_skb_foo(struct __sk_buff *skb) {
struct bpf_sock *sk;

sk = skb->sk;
if (!sk)
return 1;

sk = bpf_sk_fullsock(sk);
if (!sk)
return 1;

if (sk->family != AF_INET6 || sk->protocol != IPPROTO_TCP)
return 1;

/* some_traffic_shaping(); */

return 1;
}

(1) The sk is read only

(2) There is no new "struct bpf_sock_common" introduced.

(3) Future kernel sock's members could be added to bpf_sock only
    instead of repeatedly adding at multiple places like currently
    in bpf_sock_ops_md, bpf_sock_addr_md, sk_reuseport_md...etc.

(4) After "sk = skb->sk", the reg holding sk is in type
    PTR_TO_SOCK_COMMON_OR_NULL.

(5) After bpf_sk_fullsock(), the return type will be in type
    PTR_TO_SOCKET_OR_NULL which is the same as the return type of
    bpf_sk_lookup_xxx().

    However, bpf_sk_fullsock() does not take refcnt.  The
    acquire_reference_state() is only depending on the return type now.
    To avoid it, a new is_acquire_function() is checked before calling
    acquire_reference_state().

(6) The WARN_ON in "release_reference_state()" is no longer an
    internal verifier bug.

    When reg->id is not found in state->refs[], it means the
    bpf_prog does something wrong like
    "bpf_sk_release(bpf_sk_fullsock(skb->sk))" where reference has
    never been acquired by calling "bpf_sk_fullsock(skb->sk)".

    A -EINVAL and a verbose are done instead of WARN_ON.  A test is
    added to the test_verifier in a later patch.

    Since the WARN_ON in "release_reference_state()" is no longer
    needed, "__release_reference_state()" is folded into
    "release_reference_state()" also.

Acked-by: Alexei Starovoitov <ast@kernel.org>
Signed-off-by: Martin KaFai Lau <kafai@fb.com>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
include/linux/bpf.h
include/uapi/linux/bpf.h
kernel/bpf/verifier.c
net/core/filter.c

index bd169a7bcc93f251776f2274aad134e03084a9d4..a60463b45b542e35f155187ec4a1b4181d87687e 100644 (file)
@@ -194,6 +194,7 @@ enum bpf_arg_type {
        ARG_ANYTHING,           /* any (initialized) argument is ok */
        ARG_PTR_TO_SOCKET,      /* pointer to bpf_sock */
        ARG_PTR_TO_SPIN_LOCK,   /* pointer to bpf_spin_lock */
+       ARG_PTR_TO_SOCK_COMMON, /* pointer to sock_common */
 };
 
 /* type of values returned from helper functions */
@@ -256,6 +257,8 @@ enum bpf_reg_type {
        PTR_TO_FLOW_KEYS,        /* reg points to bpf_flow_keys */
        PTR_TO_SOCKET,           /* reg points to struct bpf_sock */
        PTR_TO_SOCKET_OR_NULL,   /* reg points to struct bpf_sock or NULL */
+       PTR_TO_SOCK_COMMON,      /* reg points to sock_common */
+       PTR_TO_SOCK_COMMON_OR_NULL, /* reg points to sock_common or NULL */
 };
 
 /* The information passed from prog-specific *_is_valid_access
@@ -920,6 +923,9 @@ void bpf_user_rnd_init_once(void);
 u64 bpf_user_rnd_u32(u64 r1, u64 r2, u64 r3, u64 r4, u64 r5);
 
 #if defined(CONFIG_NET)
+bool bpf_sock_common_is_valid_access(int off, int size,
+                                    enum bpf_access_type type,
+                                    struct bpf_insn_access_aux *info);
 bool bpf_sock_is_valid_access(int off, int size, enum bpf_access_type type,
                              struct bpf_insn_access_aux *info);
 u32 bpf_sock_convert_ctx_access(enum bpf_access_type type,
@@ -928,6 +934,12 @@ u32 bpf_sock_convert_ctx_access(enum bpf_access_type type,
                                struct bpf_prog *prog,
                                u32 *target_size);
 #else
+static inline bool bpf_sock_common_is_valid_access(int off, int size,
+                                                  enum bpf_access_type type,
+                                                  struct bpf_insn_access_aux *info)
+{
+       return false;
+}
 static inline bool bpf_sock_is_valid_access(int off, int size,
                                            enum bpf_access_type type,
                                            struct bpf_insn_access_aux *info)
index 1777fa0c61e4a2dd9f15b1908a91a8582c8cad98..5d79cba74ddc816d4b3932a601dbf4bd1a34fc23 100644 (file)
@@ -2329,6 +2329,14 @@ union bpf_attr {
  *             "**y**".
  *     Return
  *             0
+ *
+ * struct bpf_sock *bpf_sk_fullsock(struct bpf_sock *sk)
+ *     Description
+ *             This helper gets a **struct bpf_sock** pointer such
+ *             that all the fields in bpf_sock can be accessed.
+ *     Return
+ *             A **struct bpf_sock** pointer on success, or NULL in
+ *             case of failure.
  */
 #define __BPF_FUNC_MAPPER(FN)          \
        FN(unspec),                     \
@@ -2425,7 +2433,8 @@ union bpf_attr {
        FN(msg_pop_data),               \
        FN(rc_pointer_rel),             \
        FN(spin_lock),                  \
-       FN(spin_unlock),
+       FN(spin_unlock),                \
+       FN(sk_fullsock),
 
 /* integer value in 'imm' field of BPF_CALL instruction selects which helper
  * function eBPF program intends to call
@@ -2545,6 +2554,7 @@ struct __sk_buff {
        __u64 tstamp;
        __u32 wire_len;
        __u32 gso_segs;
+       __bpf_md_ptr(struct bpf_sock *, sk);
 };
 
 struct bpf_tunnel_key {
index 516dfc6d78de7e5fa57935839036eec8b1292fb1..b755d55a379128e2bb857f3c699dd7bad0e04a5e 100644 (file)
@@ -331,10 +331,17 @@ static bool type_is_pkt_pointer(enum bpf_reg_type type)
               type == PTR_TO_PACKET_META;
 }
 
+static bool type_is_sk_pointer(enum bpf_reg_type type)
+{
+       return type == PTR_TO_SOCKET ||
+               type == PTR_TO_SOCK_COMMON;
+}
+
 static bool reg_type_may_be_null(enum bpf_reg_type type)
 {
        return type == PTR_TO_MAP_VALUE_OR_NULL ||
-              type == PTR_TO_SOCKET_OR_NULL;
+              type == PTR_TO_SOCKET_OR_NULL ||
+              type == PTR_TO_SOCK_COMMON_OR_NULL;
 }
 
 static bool type_is_refcounted(enum bpf_reg_type type)
@@ -377,6 +384,12 @@ static bool is_release_function(enum bpf_func_id func_id)
        return func_id == BPF_FUNC_sk_release;
 }
 
+static bool is_acquire_function(enum bpf_func_id func_id)
+{
+       return func_id == BPF_FUNC_sk_lookup_tcp ||
+               func_id == BPF_FUNC_sk_lookup_udp;
+}
+
 /* string representation of 'enum bpf_reg_type' */
 static const char * const reg_type_str[] = {
        [NOT_INIT]              = "?",
@@ -392,6 +405,8 @@ static const char * const reg_type_str[] = {
        [PTR_TO_FLOW_KEYS]      = "flow_keys",
        [PTR_TO_SOCKET]         = "sock",
        [PTR_TO_SOCKET_OR_NULL] = "sock_or_null",
+       [PTR_TO_SOCK_COMMON]    = "sock_common",
+       [PTR_TO_SOCK_COMMON_OR_NULL] = "sock_common_or_null",
 };
 
 static char slot_type_char[] = {
@@ -618,13 +633,10 @@ static int acquire_reference_state(struct bpf_verifier_env *env, int insn_idx)
 }
 
 /* release function corresponding to acquire_reference_state(). Idempotent. */
-static int __release_reference_state(struct bpf_func_state *state, int ptr_id)
+static int release_reference_state(struct bpf_func_state *state, int ptr_id)
 {
        int i, last_idx;
 
-       if (!ptr_id)
-               return -EFAULT;
-
        last_idx = state->acquired_refs - 1;
        for (i = 0; i < state->acquired_refs; i++) {
                if (state->refs[i].id == ptr_id) {
@@ -636,21 +648,7 @@ static int __release_reference_state(struct bpf_func_state *state, int ptr_id)
                        return 0;
                }
        }
-       return -EFAULT;
-}
-
-/* variation on the above for cases where we expect that there must be an
- * outstanding reference for the specified ptr_id.
- */
-static int release_reference_state(struct bpf_verifier_env *env, int ptr_id)
-{
-       struct bpf_func_state *state = cur_func(env);
-       int err;
-
-       err = __release_reference_state(state, ptr_id);
-       if (WARN_ON_ONCE(err != 0))
-               verbose(env, "verifier internal error: can't release reference\n");
-       return err;
+       return -EINVAL;
 }
 
 static int transfer_reference_state(struct bpf_func_state *dst,
@@ -1209,6 +1207,8 @@ static bool is_spillable_regtype(enum bpf_reg_type type)
        case CONST_PTR_TO_MAP:
        case PTR_TO_SOCKET:
        case PTR_TO_SOCKET_OR_NULL:
+       case PTR_TO_SOCK_COMMON:
+       case PTR_TO_SOCK_COMMON_OR_NULL:
                return true;
        default:
                return false;
@@ -1647,6 +1647,7 @@ static int check_sock_access(struct bpf_verifier_env *env, int insn_idx,
        struct bpf_reg_state *regs = cur_regs(env);
        struct bpf_reg_state *reg = &regs[regno];
        struct bpf_insn_access_aux info = {};
+       bool valid;
 
        if (reg->smin_value < 0) {
                verbose(env, "R%d min value is negative, either use unsigned index or do a if (index >=0) check.\n",
@@ -1654,15 +1655,28 @@ static int check_sock_access(struct bpf_verifier_env *env, int insn_idx,
                return -EACCES;
        }
 
-       if (!bpf_sock_is_valid_access(off, size, t, &info)) {
-               verbose(env, "invalid bpf_sock access off=%d size=%d\n",
-                       off, size);
-               return -EACCES;
+       switch (reg->type) {
+       case PTR_TO_SOCK_COMMON:
+               valid = bpf_sock_common_is_valid_access(off, size, t, &info);
+               break;
+       case PTR_TO_SOCKET:
+               valid = bpf_sock_is_valid_access(off, size, t, &info);
+               break;
+       default:
+               valid = false;
        }
 
-       env->insn_aux_data[insn_idx].ctx_field_size = info.ctx_field_size;
 
-       return 0;
+       if (valid) {
+               env->insn_aux_data[insn_idx].ctx_field_size =
+                       info.ctx_field_size;
+               return 0;
+       }
+
+       verbose(env, "R%d invalid %s access off=%d size=%d\n",
+               regno, reg_type_str[reg->type], off, size);
+
+       return -EACCES;
 }
 
 static bool __is_pointer_value(bool allow_ptr_leaks,
@@ -1688,8 +1702,14 @@ static bool is_ctx_reg(struct bpf_verifier_env *env, int regno)
 {
        const struct bpf_reg_state *reg = reg_state(env, regno);
 
-       return reg->type == PTR_TO_CTX ||
-              reg->type == PTR_TO_SOCKET;
+       return reg->type == PTR_TO_CTX;
+}
+
+static bool is_sk_reg(struct bpf_verifier_env *env, int regno)
+{
+       const struct bpf_reg_state *reg = reg_state(env, regno);
+
+       return type_is_sk_pointer(reg->type);
 }
 
 static bool is_pkt_reg(struct bpf_verifier_env *env, int regno)
@@ -1800,6 +1820,9 @@ static int check_ptr_alignment(struct bpf_verifier_env *env,
        case PTR_TO_SOCKET:
                pointer_desc = "sock ";
                break;
+       case PTR_TO_SOCK_COMMON:
+               pointer_desc = "sock_common ";
+               break;
        default:
                break;
        }
@@ -2003,11 +2026,14 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
                         * PTR_TO_PACKET[_META,_END]. In the latter
                         * case, we know the offset is zero.
                         */
-                       if (reg_type == SCALAR_VALUE)
+                       if (reg_type == SCALAR_VALUE) {
                                mark_reg_unknown(env, regs, value_regno);
-                       else
+                       } else {
                                mark_reg_known_zero(env, regs,
                                                    value_regno);
+                               if (reg_type_may_be_null(reg_type))
+                                       regs[value_regno].id = ++env->id_gen;
+                       }
                        regs[value_regno].type = reg_type;
                }
 
@@ -2053,9 +2079,10 @@ static int check_mem_access(struct bpf_verifier_env *env, int insn_idx, u32 regn
                err = check_flow_keys_access(env, off, size);
                if (!err && t == BPF_READ && value_regno >= 0)
                        mark_reg_unknown(env, regs, value_regno);
-       } else if (reg->type == PTR_TO_SOCKET) {
+       } else if (type_is_sk_pointer(reg->type)) {
                if (t == BPF_WRITE) {
-                       verbose(env, "cannot write into socket\n");
+                       verbose(env, "R%d cannot write into %s\n",
+                               regno, reg_type_str[reg->type]);
                        return -EACCES;
                }
                err = check_sock_access(env, insn_idx, regno, off, size, t);
@@ -2102,7 +2129,8 @@ static int check_xadd(struct bpf_verifier_env *env, int insn_idx, struct bpf_ins
 
        if (is_ctx_reg(env, insn->dst_reg) ||
            is_pkt_reg(env, insn->dst_reg) ||
-           is_flow_key_reg(env, insn->dst_reg)) {
+           is_flow_key_reg(env, insn->dst_reg) ||
+           is_sk_reg(env, insn->dst_reg)) {
                verbose(env, "BPF_XADD stores into R%d %s is not allowed\n",
                        insn->dst_reg,
                        reg_type_str[reg_state(env, insn->dst_reg)->type]);
@@ -2369,6 +2397,11 @@ static int check_func_arg(struct bpf_verifier_env *env, u32 regno,
                err = check_ctx_reg(env, reg, regno);
                if (err < 0)
                        return err;
+       } else if (arg_type == ARG_PTR_TO_SOCK_COMMON) {
+               expected_type = PTR_TO_SOCK_COMMON;
+               /* Any sk pointer can be ARG_PTR_TO_SOCK_COMMON */
+               if (!type_is_sk_pointer(type))
+                       goto err_type;
        } else if (arg_type == ARG_PTR_TO_SOCKET) {
                expected_type = PTR_TO_SOCKET;
                if (type != expected_type)
@@ -2783,7 +2816,7 @@ static int release_reference(struct bpf_verifier_env *env,
        for (i = 0; i <= vstate->curframe; i++)
                release_reg_references(env, vstate->frame[i], meta->ptr_id);
 
-       return release_reference_state(env, meta->ptr_id);
+       return release_reference_state(cur_func(env), meta->ptr_id);
 }
 
 static int check_func_call(struct bpf_verifier_env *env, struct bpf_insn *insn,
@@ -3049,8 +3082,11 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
                }
        } else if (is_release_function(func_id)) {
                err = release_reference(env, &meta);
-               if (err)
+               if (err) {
+                       verbose(env, "func %s#%d reference has not been acquired before\n",
+                               func_id_name(func_id), func_id);
                        return err;
+               }
        }
 
        regs = cur_regs(env);
@@ -3099,12 +3135,19 @@ static int check_helper_call(struct bpf_verifier_env *env, int func_id, int insn
                        regs[BPF_REG_0].id = ++env->id_gen;
                }
        } else if (fn->ret_type == RET_PTR_TO_SOCKET_OR_NULL) {
-               int id = acquire_reference_state(env, insn_idx);
-               if (id < 0)
-                       return id;
                mark_reg_known_zero(env, regs, BPF_REG_0);
                regs[BPF_REG_0].type = PTR_TO_SOCKET_OR_NULL;
-               regs[BPF_REG_0].id = id;
+               if (is_acquire_function(func_id)) {
+                       int id = acquire_reference_state(env, insn_idx);
+
+                       if (id < 0)
+                               return id;
+                       /* For release_reference() */
+                       regs[BPF_REG_0].id = id;
+               } else {
+                       /* For mark_ptr_or_null_reg() */
+                       regs[BPF_REG_0].id = ++env->id_gen;
+               }
        } else {
                verbose(env, "unknown return type %d of func %s#%d\n",
                        fn->ret_type, func_id_name(func_id), func_id);
@@ -3364,6 +3407,8 @@ static int adjust_ptr_min_max_vals(struct bpf_verifier_env *env,
        case PTR_TO_PACKET_END:
        case PTR_TO_SOCKET:
        case PTR_TO_SOCKET_OR_NULL:
+       case PTR_TO_SOCK_COMMON:
+       case PTR_TO_SOCK_COMMON_OR_NULL:
                verbose(env, "R%d pointer arithmetic on %s prohibited\n",
                        dst, reg_type_str[ptr_reg->type]);
                return -EACCES;
@@ -4597,6 +4642,8 @@ static void mark_ptr_or_null_reg(struct bpf_func_state *state,
                        }
                } else if (reg->type == PTR_TO_SOCKET_OR_NULL) {
                        reg->type = PTR_TO_SOCKET;
+               } else if (reg->type == PTR_TO_SOCK_COMMON_OR_NULL) {
+                       reg->type = PTR_TO_SOCK_COMMON;
                }
                if (is_null || !(reg_is_refcounted(reg) ||
                                 reg_may_point_to_spin_lock(reg))) {
@@ -4621,7 +4668,7 @@ static void mark_ptr_or_null_regs(struct bpf_verifier_state *vstate, u32 regno,
        int i, j;
 
        if (reg_is_refcounted_or_null(&regs[regno]) && is_null)
-               __release_reference_state(state, id);
+               release_reference_state(state, id);
 
        for (i = 0; i < MAX_BPF_REG; i++)
                mark_ptr_or_null_reg(state, &regs[i], id, is_null);
@@ -5790,6 +5837,8 @@ static bool regsafe(struct bpf_reg_state *rold, struct bpf_reg_state *rcur,
        case PTR_TO_FLOW_KEYS:
        case PTR_TO_SOCKET:
        case PTR_TO_SOCKET_OR_NULL:
+       case PTR_TO_SOCK_COMMON:
+       case PTR_TO_SOCK_COMMON_OR_NULL:
                /* Only valid matches are exact, which memcmp() above
                 * would have accepted
                 */
@@ -6110,6 +6159,8 @@ static bool reg_type_mismatch_ok(enum bpf_reg_type type)
        case PTR_TO_CTX:
        case PTR_TO_SOCKET:
        case PTR_TO_SOCKET_OR_NULL:
+       case PTR_TO_SOCK_COMMON:
+       case PTR_TO_SOCK_COMMON_OR_NULL:
                return false;
        default:
                return true;
@@ -7112,6 +7163,7 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
                        convert_ctx_access = ops->convert_ctx_access;
                        break;
                case PTR_TO_SOCKET:
+               case PTR_TO_SOCK_COMMON:
                        convert_ctx_access = bpf_sock_convert_ctx_access;
                        break;
                default:
index 3a49f68eda10904ca9bd4c4ef3f8bd55f14bc54e..401d2e0aebf873f5acdf4eb3db408f70d2c47120 100644 (file)
@@ -1793,6 +1793,20 @@ static const struct bpf_func_proto bpf_skb_pull_data_proto = {
        .arg2_type      = ARG_ANYTHING,
 };
 
+BPF_CALL_1(bpf_sk_fullsock, struct sock *, sk)
+{
+       sk = sk_to_full_sk(sk);
+
+       return sk_fullsock(sk) ? (unsigned long)sk : (unsigned long)NULL;
+}
+
+static const struct bpf_func_proto bpf_sk_fullsock_proto = {
+       .func           = bpf_sk_fullsock,
+       .gpl_only       = false,
+       .ret_type       = RET_PTR_TO_SOCKET_OR_NULL,
+       .arg1_type      = ARG_PTR_TO_SOCK_COMMON,
+};
+
 static inline int sk_skb_try_make_writable(struct sk_buff *skb,
                                           unsigned int write_len)
 {
@@ -5406,6 +5420,8 @@ cg_skb_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
        switch (func_id) {
        case BPF_FUNC_get_local_storage:
                return &bpf_get_local_storage_proto;
+       case BPF_FUNC_sk_fullsock:
+               return &bpf_sk_fullsock_proto;
        default:
                return sk_filter_func_proto(func_id, prog);
        }
@@ -5477,6 +5493,8 @@ tc_cls_act_func_proto(enum bpf_func_id func_id, const struct bpf_prog *prog)
                return &bpf_get_socket_uid_proto;
        case BPF_FUNC_fib_lookup:
                return &bpf_skb_fib_lookup_proto;
+       case BPF_FUNC_sk_fullsock:
+               return &bpf_sk_fullsock_proto;
 #ifdef CONFIG_XFRM
        case BPF_FUNC_skb_get_xfrm_state:
                return &bpf_skb_get_xfrm_state_proto;
@@ -5764,6 +5782,11 @@ static bool bpf_skb_is_valid_access(int off, int size, enum bpf_access_type type
                if (size != sizeof(__u64))
                        return false;
                break;
+       case offsetof(struct __sk_buff, sk):
+               if (type == BPF_WRITE || size != sizeof(__u64))
+                       return false;
+               info->reg_type = PTR_TO_SOCK_COMMON_OR_NULL;
+               break;
        default:
                /* Only narrow read access allowed for now. */
                if (type == BPF_WRITE) {
@@ -5950,6 +5973,18 @@ static bool __sock_filter_check_size(int off, int size,
        return size == size_default;
 }
 
+bool bpf_sock_common_is_valid_access(int off, int size,
+                                    enum bpf_access_type type,
+                                    struct bpf_insn_access_aux *info)
+{
+       switch (off) {
+       case bpf_ctx_range_till(struct bpf_sock, type, priority):
+               return false;
+       default:
+               return bpf_sock_is_valid_access(off, size, type, info);
+       }
+}
+
 bool bpf_sock_is_valid_access(int off, int size, enum bpf_access_type type,
                              struct bpf_insn_access_aux *info)
 {
@@ -6748,6 +6783,13 @@ static u32 bpf_convert_ctx_access(enum bpf_access_type type,
                off += offsetof(struct qdisc_skb_cb, pkt_len);
                *target_size = 4;
                *insn++ = BPF_LDX_MEM(BPF_W, si->dst_reg, si->src_reg, off);
+               break;
+
+       case offsetof(struct __sk_buff, sk):
+               *insn++ = BPF_LDX_MEM(BPF_FIELD_SIZEOF(struct sk_buff, sk),
+                                     si->dst_reg, si->src_reg,
+                                     offsetof(struct sk_buff, sk));
+               break;
        }
 
        return insn - insn_buf;