bpf: x64: add JIT support for multi-function programs
authorAlexei Starovoitov <ast@fb.com>
Fri, 15 Dec 2017 01:55:15 +0000 (17:55 -0800)
committerDaniel Borkmann <daniel@iogearbox.net>
Sun, 17 Dec 2017 19:34:36 +0000 (20:34 +0100)
Typical JIT does several passes over bpf instructions to
compute total size and relative offsets of jumps and calls.
With multiple bpf functions calling each other all relative calls
will have invalid offsets initially, therefore we need an additional
last pass over the program to emit calls with correct offsets.
For example in case of three bpf functions:
main:
  call foo
  call bpf_map_lookup
  exit
foo:
  call bar
  exit
bar:
  exit

We will call bpf_int_jit_compile() independently for main(), foo() and bar().
x64 JIT typically does 4-5 passes to converge.
After these initial passes the image for these 3 functions
will be good except call targets, since start addresses of
foo() and bar() are unknown when we were JITing main()
(note that call bpf_map_lookup will be resolved properly
during initial passes).
Once start addresses of 3 functions are known we patch
call_insn->imm to point to right functions and call
bpf_int_jit_compile() again which needs only one pass.
Additional safety checks are done to make sure this
last pass doesn't produce image that is larger or smaller
than previous pass.

When constant blinding is on it's applied to all functions
at the first pass, since doing it once again at the last
pass can change size of the JITed code.

Tested on x64 and arm64 hw with JIT on/off, blinding on/off.
x64 jits bpf-to-bpf calls correctly while arm64 falls back to interpreter.
All other JITs that support normal BPF_CALL will behave the same way
since bpf-to-bpf call is equivalent to bpf-to-kernel call from
the JIT's point of view.

Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Daniel Borkmann <daniel@iogearbox.net>
Signed-off-by: Daniel Borkmann <daniel@iogearbox.net>
arch/x86/net/bpf_jit_comp.c
include/linux/bpf.h
include/linux/bpf_verifier.h
include/linux/filter.h
kernel/bpf/core.c
kernel/bpf/syscall.c
kernel/bpf/verifier.c

index 68859b58ab8415e6d13d66dc25a33dcdd8bf08a7..87f214fbe66ec163d24b12b6defc7edab612ecc9 100644 (file)
@@ -1109,13 +1109,23 @@ common_load:
        return proglen;
 }
 
+struct x64_jit_data {
+       struct bpf_binary_header *header;
+       int *addrs;
+       u8 *image;
+       int proglen;
+       struct jit_context ctx;
+};
+
 struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 {
        struct bpf_binary_header *header = NULL;
        struct bpf_prog *tmp, *orig_prog = prog;
+       struct x64_jit_data *jit_data;
        int proglen, oldproglen = 0;
        struct jit_context ctx = {};
        bool tmp_blinded = false;
+       bool extra_pass = false;
        u8 *image = NULL;
        int *addrs;
        int pass;
@@ -1135,10 +1145,28 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                prog = tmp;
        }
 
+       jit_data = prog->aux->jit_data;
+       if (!jit_data) {
+               jit_data = kzalloc(sizeof(*jit_data), GFP_KERNEL);
+               if (!jit_data) {
+                       prog = orig_prog;
+                       goto out;
+               }
+               prog->aux->jit_data = jit_data;
+       }
+       addrs = jit_data->addrs;
+       if (addrs) {
+               ctx = jit_data->ctx;
+               oldproglen = jit_data->proglen;
+               image = jit_data->image;
+               header = jit_data->header;
+               extra_pass = true;
+               goto skip_init_addrs;
+       }
        addrs = kmalloc(prog->len * sizeof(*addrs), GFP_KERNEL);
        if (!addrs) {
                prog = orig_prog;
-               goto out;
+               goto out_addrs;
        }
 
        /* Before first pass, make a rough estimation of addrs[]
@@ -1149,6 +1177,7 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                addrs[i] = proglen;
        }
        ctx.cleanup_addr = proglen;
+skip_init_addrs:
 
        /* JITed image shrinks with every pass and the loop iterates
         * until the image stops shrinking. Very large bpf programs
@@ -1189,7 +1218,15 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
 
        if (image) {
                bpf_flush_icache(header, image + proglen);
-               bpf_jit_binary_lock_ro(header);
+               if (!prog->is_func || extra_pass) {
+                       bpf_jit_binary_lock_ro(header);
+               } else {
+                       jit_data->addrs = addrs;
+                       jit_data->ctx = ctx;
+                       jit_data->proglen = proglen;
+                       jit_data->image = image;
+                       jit_data->header = header;
+               }
                prog->bpf_func = (void *)image;
                prog->jited = 1;
                prog->jited_len = proglen;
@@ -1197,8 +1234,12 @@ struct bpf_prog *bpf_int_jit_compile(struct bpf_prog *prog)
                prog = orig_prog;
        }
 
+       if (!prog->is_func || extra_pass) {
 out_addrs:
-       kfree(addrs);
+               kfree(addrs);
+               kfree(jit_data);
+               prog->aux->jit_data = NULL;
+       }
 out:
        if (tmp_blinded)
                bpf_jit_prog_release_other(prog, prog == orig_prog ?
index 8935f6f63d5f0261cb1b5a22070747eaf6a68ca8..da54ef644fcdcefec46ce7b8002ae43b7ef37a66 100644 (file)
@@ -200,6 +200,9 @@ struct bpf_prog_aux {
        u32 max_ctx_offset;
        u32 stack_depth;
        u32 id;
+       u32 func_cnt;
+       struct bpf_prog **func;
+       void *jit_data; /* JIT specific data. arch dependent */
        struct latch_tree_node ksym_tnode;
        struct list_head ksym_lnode;
        const struct bpf_prog_ops *ops;
index 585d4e17ea88bd69eb4a3d79ae75e46b2c34508b..aaac589e490c9c19c0b0970a29092bcfb1632203 100644 (file)
@@ -143,6 +143,7 @@ struct bpf_insn_aux_data {
        union {
                enum bpf_reg_type ptr_type;     /* pointer type for load/store insns */
                struct bpf_map *map_ptr;        /* pointer for call insn into lookup_elem */
+               s32 call_imm;                   /* saved imm field of call insn */
        };
        int ctx_field_size; /* the ctx field size for load insn, maybe 0 */
        bool seen; /* this insn was processed by the verifier */
index 3d6edc34932c6c95393f05b4b3d69ce69413c1e5..e872b4ebaa5759f64650b194cc2730e1f952f4ec 100644 (file)
@@ -463,6 +463,8 @@ struct bpf_prog {
                                gpl_compatible:1, /* Is filter GPL compatible? */
                                cb_access:1,    /* Is control block accessed? */
                                dst_needed:1,   /* Do we need dst entry? */
+                               blinded:1,      /* Was blinded */
+                               is_func:1,      /* program is a bpf function */
                                kprobe_override:1; /* Do we override a kprobe? */
        enum bpf_prog_type      type;           /* Type of BPF program */
        u32                     len;            /* Number of filter blocks */
index bda911644b1cba4fe6b05033df9593886ff89f1f..768e0a02d8c8a94229f26f6bfda68fc3f7ce6383 100644 (file)
@@ -722,7 +722,7 @@ struct bpf_prog *bpf_jit_blind_constants(struct bpf_prog *prog)
        struct bpf_insn *insn;
        int i, rewritten;
 
-       if (!bpf_jit_blinding_enabled(prog))
+       if (!bpf_jit_blinding_enabled(prog) || prog->blinded)
                return prog;
 
        clone = bpf_prog_clone_create(prog, GFP_USER);
@@ -764,6 +764,7 @@ struct bpf_prog *bpf_jit_blind_constants(struct bpf_prog *prog)
                i        += insn_delta;
        }
 
+       clone->blinded = 1;
        return clone;
 }
 #endif /* CONFIG_BPF_JIT */
@@ -1629,11 +1630,19 @@ int bpf_prog_array_copy_info(struct bpf_prog_array __rcu *array,
 static void bpf_prog_free_deferred(struct work_struct *work)
 {
        struct bpf_prog_aux *aux;
+       int i;
 
        aux = container_of(work, struct bpf_prog_aux, work);
        if (bpf_prog_is_dev_bound(aux))
                bpf_prog_offload_destroy(aux->prog);
-       bpf_jit_free(aux->prog);
+       for (i = 0; i < aux->func_cnt; i++)
+               bpf_jit_free(aux->func[i]);
+       if (aux->func_cnt) {
+               kfree(aux->func);
+               bpf_prog_unlock_free(aux->prog);
+       } else {
+               bpf_jit_free(aux->prog);
+       }
 }
 
 /* Free internal BPF program */
index 2c4cfeaa8d5e785f16758be08cb8a462766363d9..e2e1c78ce1dcc6c2b40d0a0437ad9f62ac724f48 100644 (file)
@@ -1194,7 +1194,8 @@ static int bpf_prog_load(union bpf_attr *attr)
                goto free_used_maps;
 
        /* eBPF program is ready to be JITed */
-       prog = bpf_prog_select_runtime(prog, &err);
+       if (!prog->bpf_func)
+               prog = bpf_prog_select_runtime(prog, &err);
        if (err < 0)
                goto free_used_maps;
 
index 8e0e4cd0d5e459e4827c4e69ac70d77c1b9575d6..48b2901cf4830201e9d2502919be157b2949f024 100644 (file)
@@ -5012,12 +5012,138 @@ static int convert_ctx_accesses(struct bpf_verifier_env *env)
        return 0;
 }
 
+static int jit_subprogs(struct bpf_verifier_env *env)
+{
+       struct bpf_prog *prog = env->prog, **func, *tmp;
+       int i, j, subprog_start, subprog_end = 0, len, subprog;
+       struct bpf_insn *insn = prog->insnsi;
+       void *old_bpf_func;
+       int err = -ENOMEM;
+
+       if (env->subprog_cnt == 0)
+               return 0;
+
+       for (i = 0; i < prog->len; i++, insn++) {
+               if (insn->code != (BPF_JMP | BPF_CALL) ||
+                   insn->src_reg != BPF_PSEUDO_CALL)
+                       continue;
+               subprog = find_subprog(env, i + insn->imm + 1);
+               if (subprog < 0) {
+                       WARN_ONCE(1, "verifier bug. No program starts at insn %d\n",
+                                 i + insn->imm + 1);
+                       return -EFAULT;
+               }
+               /* temporarily remember subprog id inside insn instead of
+                * aux_data, since next loop will split up all insns into funcs
+                */
+               insn->off = subprog + 1;
+               /* remember original imm in case JIT fails and fallback
+                * to interpreter will be needed
+                */
+               env->insn_aux_data[i].call_imm = insn->imm;
+               /* point imm to __bpf_call_base+1 from JITs point of view */
+               insn->imm = 1;
+       }
+
+       func = kzalloc(sizeof(prog) * (env->subprog_cnt + 1), GFP_KERNEL);
+       if (!func)
+               return -ENOMEM;
+
+       for (i = 0; i <= env->subprog_cnt; i++) {
+               subprog_start = subprog_end;
+               if (env->subprog_cnt == i)
+                       subprog_end = prog->len;
+               else
+                       subprog_end = env->subprog_starts[i];
+
+               len = subprog_end - subprog_start;
+               func[i] = bpf_prog_alloc(bpf_prog_size(len), GFP_USER);
+               if (!func[i])
+                       goto out_free;
+               memcpy(func[i]->insnsi, &prog->insnsi[subprog_start],
+                      len * sizeof(struct bpf_insn));
+               func[i]->len = len;
+               func[i]->is_func = 1;
+               /* Use bpf_prog_F_tag to indicate functions in stack traces.
+                * Long term would need debug info to populate names
+                */
+               func[i]->aux->name[0] = 'F';
+               func[i]->aux->stack_depth = env->subprog_stack_depth[i];
+               func[i]->jit_requested = 1;
+               func[i] = bpf_int_jit_compile(func[i]);
+               if (!func[i]->jited) {
+                       err = -ENOTSUPP;
+                       goto out_free;
+               }
+               cond_resched();
+       }
+       /* at this point all bpf functions were successfully JITed
+        * now populate all bpf_calls with correct addresses and
+        * run last pass of JIT
+        */
+       for (i = 0; i <= env->subprog_cnt; i++) {
+               insn = func[i]->insnsi;
+               for (j = 0; j < func[i]->len; j++, insn++) {
+                       if (insn->code != (BPF_JMP | BPF_CALL) ||
+                           insn->src_reg != BPF_PSEUDO_CALL)
+                               continue;
+                       subprog = insn->off;
+                       insn->off = 0;
+                       insn->imm = (u64 (*)(u64, u64, u64, u64, u64))
+                               func[subprog]->bpf_func -
+                               __bpf_call_base;
+               }
+       }
+       for (i = 0; i <= env->subprog_cnt; i++) {
+               old_bpf_func = func[i]->bpf_func;
+               tmp = bpf_int_jit_compile(func[i]);
+               if (tmp != func[i] || func[i]->bpf_func != old_bpf_func) {
+                       verbose(env, "JIT doesn't support bpf-to-bpf calls\n");
+                       err = -EFAULT;
+                       goto out_free;
+               }
+               cond_resched();
+       }
+
+       /* finally lock prog and jit images for all functions and
+        * populate kallsysm
+        */
+       for (i = 0; i <= env->subprog_cnt; i++) {
+               bpf_prog_lock_ro(func[i]);
+               bpf_prog_kallsyms_add(func[i]);
+       }
+       prog->jited = 1;
+       prog->bpf_func = func[0]->bpf_func;
+       prog->aux->func = func;
+       prog->aux->func_cnt = env->subprog_cnt + 1;
+       return 0;
+out_free:
+       for (i = 0; i <= env->subprog_cnt; i++)
+               if (func[i])
+                       bpf_jit_free(func[i]);
+       kfree(func);
+       /* cleanup main prog to be interpreted */
+       prog->jit_requested = 0;
+       for (i = 0, insn = prog->insnsi; i < prog->len; i++, insn++) {
+               if (insn->code != (BPF_JMP | BPF_CALL) ||
+                   insn->src_reg != BPF_PSEUDO_CALL)
+                       continue;
+               insn->off = 0;
+               insn->imm = env->insn_aux_data[i].call_imm;
+       }
+       return err;
+}
+
 static int fixup_call_args(struct bpf_verifier_env *env)
 {
        struct bpf_prog *prog = env->prog;
        struct bpf_insn *insn = prog->insnsi;
        int i, depth;
 
+       if (env->prog->jit_requested)
+               if (jit_subprogs(env) == 0)
+                       return 0;
+
        for (i = 0; i < prog->len; i++, insn++) {
                if (insn->code != (BPF_JMP | BPF_CALL) ||
                    insn->src_reg != BPF_PSEUDO_CALL)