From 859751fe977c173aa5623f5a6c3014c8699d2402 Mon Sep 17 00:00:00 2001
From: Felix Fietkau
Date: Wed, 29 Jun 2022 20:12:48 +0200
Subject: [PATCH] vxlan: add bpf program to fix up tcp mss values

Add a tc BPF program (mss-bpf.c) that lowers the MSS option of TCP SYN
packets, and a loader (bpf.c) that attaches it to the ingress and egress
hooks of every vxlan device created by unetd.

The BPF object is read from /lib/bpf/mss.o by default; the new -M option
overrides that path.

The rtnetlink socket helpers move from vxlan.c into rtnl.c so that the
BPF loader can call rtnl_init() as well.

Signed-off-by: Felix Fietkau
---
Review notes (text in this section is not part of the commit message):
 - parse_ipv6() in mss-bpf.c now returns int instead of bool. With bool, both
   -1 and the next-header value were converted to 0/1, so the IPPROTO_TCP
   comparison in mssfix() could never match for IPv6.
 - rtnl_call() returns -1 when the rtnetlink socket has not been set up. This
   replaces the "if (!rtnl) return;" check that this patch removes from
   vxlan_tunnel_teardown().
 - Dropped the unused "ipproto" local from mssfix().
 - find_library() in CMakeLists.txt is spelled in upper case like the
   surrounding commands.
 - Added comments to the new files; hunk sizes and the diffstat were updated
   to match. The "index" blob ids of the new files were not regenerated.
 - The header names inside the angle-bracket #include directives and the
   author's e-mail address are missing from this copy of the patch. They have
   to be restored from the original commit before it can be applied and built.

 CMakeLists.txt |   6 +-
 bpf.c          | 127 ++++++++++++++++++++++++++
 main.c         |   6 +-
 mss-bpf.c      | 238 +++++++++++++++++++++++++++++++++++++++++++++++++
 rtnl.c         | 113 +++++++++++++++++++++++
 unetd.h        |   4 +
 utils.h        |   5 +
 vxlan.c        | 118 ++++--------------------
 8 files changed, 515 insertions(+), 102 deletions(-)
 create mode 100644 bpf.c
 create mode 100644 mss-bpf.c
 create mode 100644 rtnl.c

diff --git a/CMakeLists.txt b/CMakeLists.txt
index 73c4cfa..4ac58f1 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -17,9 +17,11 @@ FIND_LIBRARY(libjson NAMES json-c json)
 OPTION(UBUS_SUPPORT "enable ubus support" ON)
 IF(CMAKE_SYSTEM_NAME STREQUAL "Linux")
 	FIND_LIBRARY(nl nl-tiny)
-	SET(SOURCES ${SOURCES} wg-linux.c vxlan.c)
+	FIND_LIBRARY(bpf NAMES bpf)
+	SET(SOURCES ${SOURCES} wg-linux.c vxlan.c bpf.c rtnl.c)
 ELSE()
 	SET(nl "")
+	SET(bpf "")
 ENDIF()
 IF(UBUS_SUPPORT)
 	SET(SOURCES ${SOURCES} ubus.c)
@@ -30,7 +32,7 @@ ELSE()
 ENDIF()
 
 ADD_EXECUTABLE(unetd ${SOURCES})
-TARGET_LINK_LIBRARIES(unetd ubox ${ubus} blobmsg_json ${libjson} ${nl})
+TARGET_LINK_LIBRARIES(unetd ubox ${ubus} blobmsg_json ${libjson} ${nl} ${bpf})
 
 INSTALL(TARGETS unetd
 	RUNTIME DESTINATION sbin
diff --git a/bpf.c b/bpf.c
new file mode 100644
index 0000000..d0ad683
--- /dev/null
+++ b/bpf.c
@@ -0,0 +1,127 @@
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * Copyright (C) 2022 Felix Fietkau
+ */
+#define _GNU_SOURCE
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include "unetd.h"
+
+/* libbpf print callback: write every libbpf message to stderr. */
+static int unetd_bpf_pr(enum libbpf_print_level level, const char *format,
+			va_list args)
+{
+	return vfprintf(stderr, format, args);
+}
+
+/* Raise RLIMIT_MEMLOCK to unlimited; the setrlimit() result is not checked. */
+static void unetd_init_env(void)
+{
+	struct rlimit limit = {
+		.rlim_cur = RLIM_INFINITY,
+		.rlim_max = RLIM_INFINITY,
+	};
+
+	setrlimit(RLIMIT_MEMLOCK, &limit);
+}
+
+/*
+ * Set the initial value of every map of @obj whose name contains ".rodata"
+ * to @mtu. unetd_attach_mssfix() calls this before bpf_object__load().
+ */
+static void
+unetd_set_prog_mtu(struct bpf_object *obj, uint32_t mtu)
+{
+	struct bpf_map *map = NULL;
+
+	while ((map = bpf_object__next_map(obj, map)) != NULL) {
+		if (!strstr(bpf_map__name(map), ".rodata"))
+			continue;
+
+		bpf_map__set_initial_value(map, &mtu, sizeof(mtu));
+	}
+}
+
+/*
+ * Attach program @fd to the tc ingress or egress hook of @ifindex with
+ * BPF_TC_F_REPLACE at handle 1 / priority UNETD_MSS_PRIO_BASE.
+ * Returns the result of bpf_tc_attach().
+ */
+static int
+unetd_attach_bpf_prog(int ifindex, int fd, bool egress)
+{
+	DECLARE_LIBBPF_OPTS(bpf_tc_hook, hook,
+			    .attach_point = egress ? BPF_TC_EGRESS : BPF_TC_INGRESS,
+			    .ifindex = ifindex);
+	DECLARE_LIBBPF_OPTS(bpf_tc_opts, attach_tc,
+			    .flags = BPF_TC_F_REPLACE,
+			    .handle = 1,
+			    .prog_fd = fd,
+			    .priority = UNETD_MSS_PRIO_BASE);
+
+	/* result ignored; NOTE(review): presumably the hook may already exist */
+	bpf_tc_hook_create(&hook);
+
+	return bpf_tc_attach(&hook, &attach_tc);
+}
+
+/*
+ * Load the BPF object at mssfix_path, set its MTU constant to @mtu and attach
+ * its "mssfix" program to both the egress and ingress tc hooks of @ifindex.
+ * Returns 0 on success, -1 on any failure.
+ */
+int unetd_attach_mssfix(int ifindex, int mtu)
+{
+	struct bpf_program *prog;
+	struct bpf_object *obj;
+	int prog_fd;
+	int ret = -1;
+
+	if (rtnl_init())
+		return -1;
+
+	unetd_init_env();
+	libbpf_set_print(unetd_bpf_pr);
+
+	obj = bpf_object__open_file(mssfix_path, NULL);
+	if (libbpf_get_error(obj)) {
+		perror("bpf_object__open_file");
+		goto out;
+	}
+
+	prog = bpf_object__find_program_by_name(obj, "mssfix");
+	if (!prog) {
+		perror("bpf_object__find_program_by_name");
+		goto out;
+	}
+
+	bpf_program__set_type(prog, BPF_PROG_TYPE_SCHED_CLS);
+
+	unetd_set_prog_mtu(obj, mtu);
+
+	if (bpf_object__load(obj)) {
+		perror("bpf_object__load");
+		goto out;
+	}
+
+	prog_fd = bpf_program__fd(prog);
+	if (unetd_attach_bpf_prog(ifindex, prog_fd, true) ||
+	    unetd_attach_bpf_prog(ifindex, prog_fd, false))
+		goto out;
+
+	ret = 0;
+
+out:
+	bpf_object__close(obj);
+
+	return ret;
+}
diff --git a/main.c b/main.c
index 74fe964..3bd7744 100644
--- a/main.c
+++ b/main.c
@@ -16,6 +16,7 @@ struct cmdline_network {
 
 static struct cmdline_network *cmd_nets;
 static const char *hosts_file;
+const char *mssfix_path = UNETD_MSS_BPF_PATH;
 bool dummy_mode;
 bool debug;
 
@@ -97,7 +98,7 @@ int main(int argc, char **argv)
 	struct cmdline_network *net;
 	int ch;
 
-	while ((ch = getopt(argc, argv, "Ddh:N:")) != -1) {
+	while ((ch = getopt(argc, argv, "Ddh:M:N:")) != -1) {
 		switch (ch) {
 		case 'd':
 			debug = true;
@@ -114,6 +115,9 @@ int main(int argc, char **argv)
 			net->data = optarg;
 			cmd_nets = net;
 			break;
+		case 'M':
+			mssfix_path = optarg;
+			break;
 		}
 	}
 
diff --git a/mss-bpf.c b/mss-bpf.c
new file mode 100644
index 0000000..5f5cb61
--- /dev/null
+++ b/mss-bpf.c
@@ -0,0 +1,238 @@
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * Copyright (C) 2021 Felix Fietkau
+ */
+#define KBUILD_MODNAME "foo"
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+/* MTU the clamped MSS is derived from; bpf.c overrides the .rodata value */
+const volatile static uint32_t mtu = 1420;
+
+static __always_inline int proto_is_vlan(__u16 h_proto)
+{
+	return !!(h_proto == bpf_htons(ETH_P_8021Q) ||
+		  h_proto == bpf_htons(ETH_P_8021AD));
+}
+
+static __always_inline int proto_is_ip(__u16 h_proto)
+{
+	return !!(h_proto == bpf_htons(ETH_P_IP) ||
+		  h_proto == bpf_htons(ETH_P_IPV6));
+}
+
+static __always_inline void *skb_ptr(struct __sk_buff *skb, __u32 offset)
+{
+	void *start = (void *)(unsigned long long)skb->data;
+
+	return start + offset;
+}
+
+static __always_inline void *skb_end_ptr(struct __sk_buff *skb)
+{
+	return (void *)(unsigned long long)skb->data_end;
+}
+
+/* Return -1 if @ptr lies beyond the end of the packet data, 0 otherwise. */
+static __always_inline int skb_check(struct __sk_buff *skb, void *ptr)
+{
+	if (ptr > skb_end_ptr(skb))
+		return -1;
+
+	return 0;
+}
+
+/*
+ * Skip the Ethernet header and up to two VLAN tags, advancing *offset.
+ * Returns the inner EtherType in network byte order, or -1 if truncated.
+ */
+static __always_inline int
+parse_ethernet(struct __sk_buff *skb, __u32 *offset)
+{
+	struct ethhdr *eth;
+	__u16 h_proto;
+	int i;
+
+	eth = skb_ptr(skb, *offset);
+	if (skb_check(skb, eth + 1))
+		return -1;
+
+	h_proto = eth->h_proto;
+	*offset += sizeof(*eth);
+
+#pragma unroll
+	for (i = 0; i < 2; i++) {
+		struct vlan_hdr *vlh = skb_ptr(skb, *offset);
+
+		if (!proto_is_vlan(h_proto))
+			break;
+
+		if (skb_check(skb, vlh + 1))
+			return -1;
+
+		h_proto = vlh->h_vlan_encapsulated_proto;
+		*offset += sizeof(*vlh);
+	}
+
+	return h_proto;
+}
+
+/*
+ * Advance *offset past the IPv4 header (ihl * 4 bytes).
+ * Returns the IP protocol number, or -1 on failure.
+ */
+static __always_inline int
+parse_ipv4(struct __sk_buff *skb, __u32 *offset)
+{
+	struct iphdr *iph;
+	int hdr_len;
+
+	iph = skb_ptr(skb, *offset);
+	if (skb_check(skb, iph + 1))
+		return -1;
+
+	hdr_len = iph->ihl * 4;
+	if (bpf_skb_pull_data(skb, *offset + hdr_len + sizeof(struct tcphdr) + 20))
+		return -1;
+
+	iph = skb_ptr(skb, *offset);
+	*offset += hdr_len;
+
+	if (skb_check(skb, (void *)(iph + 1)))
+		return -1;
+
+	return READ_ONCE(iph->protocol);
+}
+
+/*
+ * Advance *offset past the fixed IPv6 header.
+ * Returns the next-header value, or -1 on failure. The return type must be
+ * int: with bool every result collapses to 0/1 and never equals IPPROTO_TCP.
+ */
+static __always_inline int
+parse_ipv6(struct __sk_buff *skb, __u32 *offset)
+{
+	struct ipv6hdr *iph;
+
+	if (bpf_skb_pull_data(skb, *offset + sizeof(*iph) + sizeof(struct tcphdr) + 20))
+		return -1;
+
+	iph = skb_ptr(skb, *offset);
+	*offset += sizeof(*iph);
+
+	if (skb_check(skb, (void *)(iph + 1)))
+		return -1;
+
+	return READ_ONCE(iph->nexthdr);
+}
+
+/* Option length at @opt: 1 for kinds <= TCPOPT_NOP or a zero length byte. */
+static inline unsigned int
+optlen(const u_int8_t *opt)
+{
+	if (opt[0] <= TCPOPT_NOP || opt[1] == 0)
+		return 1;
+
+	return opt[1];
+}
+
+/*
+ * On SYN segments, look at the first (up to five) TCP options for an MSS
+ * option and lower its value to @mss if it is larger, fixing the checksum.
+ */
+static __always_inline void
+fixup_tcp(struct __sk_buff *skb, __u32 offset, __u16 mss)
+{
+	struct tcphdr *tcph;
+	__u16 oldmss;
+	__u8 *opt;
+	u8 flags;
+	int hdrlen;
+	int i;
+
+	tcph = skb_ptr(skb, offset);
+	if (skb_check(skb, tcph + 1))
+		return;
+
+	flags = tcp_flag_byte(tcph);
+	if (!(flags & TCPHDR_SYN))
+		return;
+
+	hdrlen = tcph->doff * 4;
+	if (hdrlen <= sizeof(struct tcphdr))
+		return;
+
+	hdrlen += offset;
+	offset += sizeof(*tcph);
+
+#pragma unroll
+	for (i = 0; i < 5; i++) {
+		unsigned int len;
+
+		if (offset >= hdrlen)
+			return;
+
+		opt = skb_ptr(skb, offset);
+		if (skb_check(skb, opt + TCPOLEN_MSS))
+			return;
+
+		len = optlen(opt);
+		offset += len;
+		if (opt[0] != TCPOPT_MSS || opt[1] != TCPOLEN_MSS)
+			continue;
+
+		goto found;
+	}
+	return;
+
+found:
+	oldmss = (opt[2] << 8) | opt[3];
+	if (oldmss <= mss)
+		return;
+
+	opt[2] = mss >> 8;
+	opt[3] = mss & 0xff;
+	csum_replace2(&tcph->check, bpf_htons(oldmss), bpf_htons(mss));
+}
+
+/*
+ * tc classifier entry point: clamp the MSS of TCP SYNs carried in IPv4/IPv6.
+ * Always returns TC_ACT_UNSPEC.
+ */
+SEC("tc")
+int mssfix(struct __sk_buff *skb)
+{
+	__u32 offset = 0;
+	__u16 mss;
+	int type;
+
+	type = parse_ethernet(skb, &offset);
+	if (type == bpf_htons(ETH_P_IP))
+		type = parse_ipv4(skb, &offset);
+	else if (type == bpf_htons(ETH_P_IPV6))
+		type = parse_ipv6(skb, &offset);
+	else
+		return TC_ACT_UNSPEC;
+
+	if (type != IPPROTO_TCP)
+		return TC_ACT_UNSPEC;
+
+	mss = mtu;
+	/* NOTE(review): offset also covers the Ethernet/VLAN header here */
+	mss -= offset + sizeof(struct tcphdr);
+	fixup_tcp(skb, offset, mss);
+
+	return TC_ACT_UNSPEC;
+}
diff --git a/rtnl.c b/rtnl.c
new file mode 100644
index 0000000..4180575
--- /dev/null
+++ b/rtnl.c
@@ -0,0 +1,113 @@
+// SPDX-License-Identifier: GPL-2.0+
+/*
+ * Copyright (C) 2022 Felix Fietkau
+ */
+#define _GNU_SOURCE
+#include
+#include
+#include
+#include
+#include
+#include
+#include "unetd.h"
+
+static struct nl_sock *rtnl;
+bool rtnl_ignore_errors;
+
+/*
+ * Netlink error callback. Unless rtnl_ignore_errors is set, and only if the
+ * reply has NLM_F_ACK_TLVS set, log the error code and message through D().
+ */
+static int
+unetd_nl_error_cb(struct sockaddr_nl *nla, struct nlmsgerr *err,
+		  void *arg)
+{
+	struct nlmsghdr *nlh = (struct nlmsghdr *) err - 1;
+	struct nlattr *tb[NLMSGERR_ATTR_MAX + 1];
+	struct nlattr *attrs;
+	int ack_len = sizeof(*nlh) + sizeof(int) + sizeof(*nlh);
+	int len = nlh->nlmsg_len;
+	const char *errstr = "(unknown)";
+
+	if (rtnl_ignore_errors)
+		return NL_STOP;
+
+	if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS))
+		return NL_STOP;
+
+	if (!(nlh->nlmsg_flags & NLM_F_CAPPED))
+		ack_len += err->msg.nlmsg_len - sizeof(*nlh);
+
+	attrs = (void *) ((unsigned char *) nlh + ack_len);
+	len -= ack_len;
+
+	nla_parse(tb, NLMSGERR_ATTR_MAX, attrs, len, NULL);
+	if (tb[NLMSGERR_ATTR_MSG])
+		errstr = nla_data(tb[NLMSGERR_ATTR_MSG]);
+
+	D("Netlink error(%d): %s\n", err->error, errstr);
+
+	return NL_STOP;
+}
+
+/*
+ * Send @msg on the shared rtnetlink socket and wait for the ack. Always frees
+ * @msg. Returns the nl_send_auto_complete()/nl_wait_for_ack() result, or -1
+ * when the socket has not been set up by rtnl_init().
+ */
+int rtnl_call(struct nl_msg *msg)
+{
+	int ret;
+
+	/* vxlan_tunnel_teardown() no longer checks for a missing socket itself */
+	if (!rtnl) {
+		nlmsg_free(msg);
+		return -1;
+	}
+
+	ret = nl_send_auto_complete(rtnl, msg);
+	nlmsg_free(msg);
+
+	if (ret < 0)
+		return ret;
+
+	return nl_wait_for_ack(rtnl);
+}
+
+/*
+ * Create and connect the shared rtnetlink socket on first use.
+ * Returns 0 on success (or if it already exists), -1 on failure.
+ */
+int rtnl_init(void)
+{
+	int fd, opt;
+
+	if (rtnl)
+		return 0;
+
+	rtnl = nl_socket_alloc();
+	if (!rtnl)
+		return -1;
+
+	if (nl_connect(rtnl, NETLINK_ROUTE))
+		goto free;
+
+	nl_socket_disable_seq_check(rtnl);
+	nl_socket_set_buffer_size(rtnl, 65536, 0);
+	nl_cb_err(nl_socket_get_cb(rtnl), NL_CB_CUSTOM, unetd_nl_error_cb, NULL);
+
+	fd = nl_socket_get_fd(rtnl);
+
+	opt = 1;
+	setsockopt(fd, SOL_NETLINK, NETLINK_EXT_ACK, &opt, sizeof(opt));
+
+	opt = 1;
+	setsockopt(fd, SOL_NETLINK, NETLINK_CAP_ACK, &opt, sizeof(opt));
+
+	return 0;
+
+free:
+	nl_socket_free(rtnl);
+	rtnl = NULL;
+	return -1;
+}
diff --git a/unetd.h b/unetd.h
index b6fd437..799bff7 100644
--- a/unetd.h
+++ b/unetd.h
@@ -19,6 +19,7 @@
 #include "service.h"
 #include "ubus.h"
 
+extern const char *mssfix_path;
 extern bool dummy_mode;
 extern bool debug;
 
@@ -34,7 +35,10 @@ extern bool debug;
 #define D_PEER(net, peer, format, ...) D_NET(net, "host %s " format, network_peer_name(peer), ##__VA_ARGS__)
 #define D_SERVICE(net, service, format, ...) D_NET(net, "service %s " format, network_service_name(service), ##__VA_ARGS__)
 
+#define UNETD_MSS_BPF_PATH	"/lib/bpf/mss.o"
+#define UNETD_MSS_PRIO_BASE	0x130
 
 void unetd_write_hosts(void);
+int unetd_attach_mssfix(int ifindex, int mtu);
 
 #endif
diff --git a/utils.h b/utils.h
index b024374..0770807 100644
--- a/utils.h
+++ b/utils.h
@@ -7,6 +7,8 @@
 
 #include
 
+struct nl_msg;
+
 union network_addr {
 	struct {
 		uint8_t network_id[8];
@@ -82,4 +84,7 @@ static inline void bitmask_set_val(uint32_t *mask, unsigned int i, bool val)
 		bitmask_clear(mask, i);
 }
 
+int rtnl_init(void);
+int rtnl_call(struct nl_msg *msg);
+
 #endif
diff --git a/vxlan.c b/vxlan.c
index b15d4c3..f042a48 100644
--- a/vxlan.c
+++ b/vxlan.c
@@ -11,6 +11,8 @@
 #include
 #include
 #include
+#include
+#include
 #include "unetd.h"
 
 struct vxlan_tunnel {
@@ -26,39 +28,19 @@ struct vxlan_tunnel {
 	bool active;
 };
 
-static struct nl_sock *rtnl;
-static bool ignore_errors;
-
-static int
-unetd_nl_error_cb(struct sockaddr_nl *nla, struct nlmsgerr *err,
-		  void *arg)
+static uint32_t
+vxlan_tunnel_id(struct vxlan_tunnel *vt)
 {
-	struct nlmsghdr *nlh = (struct nlmsghdr *) err - 1;
-	struct nlattr *tb[NLMSGERR_ATTR_MAX + 1];
-	struct nlattr *attrs;
-	int ack_len = sizeof(*nlh) + sizeof(int) + sizeof(*nlh);
-	int len = nlh->nlmsg_len;
-	const char *errstr = "(unknown)";
-
-	if (ignore_errors)
-		return NL_STOP;
-
-	if (!(nlh->nlmsg_flags & NLM_F_ACK_TLVS))
-		return NL_STOP;
-
-	if (!(nlh->nlmsg_flags & NLM_F_CAPPED))
-		ack_len += err->msg.nlmsg_len - sizeof(*nlh);
-
-	attrs = (void *) ((unsigned char *) nlh + ack_len);
-	len -= ack_len;
+	siphash_key_t key = {};
+	const char *name = network_service_name(vt->s);
+	uint64_t val;
 
-	nla_parse(tb, NLMSGERR_ATTR_MAX, attrs, len, NULL);
-	if (tb[NLMSGERR_ATTR_MSG])
-		errstr = nla_data(tb[NLMSGERR_ATTR_MSG]);
+	if (vt->vni != ~0)
+		return vt->vni;
 
-	D("Netlink error(%d): %s\n", err->error, errstr);
+	siphash_to_le64(&val, name, strlen(name), &key);
 
-	return NL_STOP;
+	return val & 0x00ffffff;
 }
 
 static struct nl_msg *vxlan_rtnl_msg(const char *ifname, int type, int flags)
@@ -78,69 +60,6 @@ static struct nl_msg *vxlan_rtnl_msg(const char *ifname, int type, int flags)
 	return msg;
 }
 
-static int vxlan_rtnl_call(struct nl_msg *msg)
-{
-	int ret;
-
-	ret = nl_send_auto_complete(rtnl, msg);
-	nlmsg_free(msg);
-
-	if (ret < 0)
-		return ret;
-
-	return nl_wait_for_ack(rtnl);
-}
-
-static int
-vxlan_rtnl_init(void)
-{
-	int fd, opt;
-
-	if (rtnl)
-		return 0;
-
-	rtnl = nl_socket_alloc();
-	if (!rtnl)
-		return -1;
-
-	if (nl_connect(rtnl, NETLINK_ROUTE))
-		goto free;
-
-	nl_socket_disable_seq_check(rtnl);
-	nl_socket_set_buffer_size(rtnl, 65536, 0);
-	nl_cb_err(nl_socket_get_cb(rtnl), NL_CB_CUSTOM, unetd_nl_error_cb, NULL);
-
-	fd = nl_socket_get_fd(rtnl);
-
-	opt = 1;
-	setsockopt(fd, SOL_NETLINK, NETLINK_EXT_ACK, &opt, sizeof(opt));
-
-	opt = 1;
-	setsockopt(fd, SOL_NETLINK, NETLINK_CAP_ACK, &opt, sizeof(opt));
-
-	return 0;
-
-free:
-	nl_socket_free(rtnl);
-	rtnl = NULL;
-	return -1;
-}
-
-static uint32_t
-vxlan_tunnel_id(struct vxlan_tunnel *vt)
-{
-	siphash_key_t key = {};
-	const char *name = network_service_name(vt->s);
-	uint64_t val;
-
-	if (vt->vni != ~0)
-		return vt->vni;
-
-	siphash_to_le64(&val, name, strlen(name), &key);
-
-	return val & 0x00ffffff;
-}
-
 static int
 vxlan_update_host_fdb_entry(struct vxlan_tunnel *vt, struct network_host *host, bool add)
 {
@@ -163,7 +82,7 @@ vxlan_update_host_fdb_entry(struct vxlan_tunnel *vt, struct network_host *host,
 	nla_put(msg, NDA_DST, sizeof(struct in6_addr), &host->peer.local_addr);
 	nla_put_u32(msg, NDA_IFINDEX, vt->net->ifindex);
 
-	return vxlan_rtnl_call(msg);
+	return rtnl_call(msg);
 }
 
 static void
@@ -208,8 +127,9 @@ vxlan_tunnel_init(struct vxlan_tunnel *vt)
 	struct nlattr *linkinfo, *data;
 	struct nl_msg *msg;
 	struct in6_addr group_addr;
+	int mtu;
 
-	if (vxlan_rtnl_init())
+	if (rtnl_init())
 		return;
 
 	memset(&group_addr, 0xff, sizeof(group_addr));
@@ -230,7 +150,7 @@ vxlan_tunnel_init(struct vxlan_tunnel *vt)
 
 	nla_nest_end(msg, linkinfo);
 
-	if (vxlan_rtnl_call(msg) < 0)
+	if (rtnl_call(msg) < 0)
 		return;
 
 	vt->ifindex = if_nametoindex(vt->ifname);
@@ -241,6 +161,9 @@ vxlan_tunnel_init(struct vxlan_tunnel *vt)
 
 	vt->active = true;
 	vxlan_update_fdb_hosts(vt);
+
+	mtu = 1420 - sizeof(struct ipv6hdr) - sizeof(struct udphdr) - 8;
+	unetd_attach_mssfix(vt->ifindex, mtu);
 }
 
 static void
@@ -248,12 +171,9 @@ vxlan_tunnel_teardown(struct vxlan_tunnel *vt)
 {
 	struct nl_msg *msg;
 
-	if (!rtnl)
-		return;
-
 	vt->active = false;
 	msg = vxlan_rtnl_msg(vt->ifname, RTM_DELLINK, 0);
-	vxlan_rtnl_call(msg);
+	rtnl_call(msg);
 }
 
 static const char *
-- 
2.30.2