net/smc: check for missing nlattrs in SMC_PNETID messages
authorEric Biggers <ebiggers@google.com>
Mon, 14 May 2018 00:01:30 +0000 (17:01 -0700)
committerDavid S. Miller <davem@davemloft.net>
Mon, 14 May 2018 20:40:45 +0000 (16:40 -0400)
It's possible to crash the kernel in several different ways by sending
messages to the SMC_PNETID generic netlink family that are missing the
expected attributes:

- Missing SMC_PNETID_NAME => null pointer dereference when comparing
  names.
- Missing SMC_PNETID_ETHNAME => null pointer dereference accessing
  smc_pnetentry::ndev.
- Missing SMC_PNETID_IBNAME => null pointer dereference accessing
  smc_pnetentry::smcibdev.
- Missing SMC_PNETID_IBPORT => out of bounds array access to
  smc_ib_device::pattr[-1].

Fix it by validating that all expected attributes are present and that
SMC_PNETID_IBPORT is nonzero.

Reported-by: syzbot+5cd61039dc9b8bfa6e47@syzkaller.appspotmail.com
Fixes: 6812baabf24d ("smc: establish pnet table management")
Cc: <stable@vger.kernel.org> # v4.11+
Signed-off-by: Eric Biggers <ebiggers@google.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
net/smc/smc_pnet.c

index 74568cdbca7087532c20b891496edeb467b13829..d7b88b2d1b224195b2d82523c047052c67f2e1eb 100644 (file)
@@ -245,40 +245,45 @@ out:
 static int smc_pnet_fill_entry(struct net *net, struct smc_pnetentry *pnetelem,
                               struct nlattr *tb[])
 {
-       char *string, *ibname = NULL;
-       int rc = 0;
+       char *string, *ibname;
+       int rc;
 
        memset(pnetelem, 0, sizeof(*pnetelem));
        INIT_LIST_HEAD(&pnetelem->list);
-       if (tb[SMC_PNETID_NAME]) {
-               string = (char *)nla_data(tb[SMC_PNETID_NAME]);
-               if (!smc_pnetid_valid(string, pnetelem->pnet_name)) {
-                       rc = -EINVAL;
-                       goto error;
-               }
-       }
-       if (tb[SMC_PNETID_ETHNAME]) {
-               string = (char *)nla_data(tb[SMC_PNETID_ETHNAME]);
-               pnetelem->ndev = dev_get_by_name(net, string);
-               if (!pnetelem->ndev)
-                       return -ENOENT;
-       }
-       if (tb[SMC_PNETID_IBNAME]) {
-               ibname = (char *)nla_data(tb[SMC_PNETID_IBNAME]);
-               ibname = strim(ibname);
-               pnetelem->smcibdev = smc_pnet_find_ib(ibname);
-               if (!pnetelem->smcibdev) {
-                       rc = -ENOENT;
-                       goto error;
-               }
-       }
-       if (tb[SMC_PNETID_IBPORT]) {
-               pnetelem->ib_port = nla_get_u8(tb[SMC_PNETID_IBPORT]);
-               if (pnetelem->ib_port > SMC_MAX_PORTS) {
-                       rc = -EINVAL;
-                       goto error;
-               }
-       }
+
+       rc = -EINVAL;
+       if (!tb[SMC_PNETID_NAME])
+               goto error;
+       string = (char *)nla_data(tb[SMC_PNETID_NAME]);
+       if (!smc_pnetid_valid(string, pnetelem->pnet_name))
+               goto error;
+
+       rc = -EINVAL;
+       if (!tb[SMC_PNETID_ETHNAME])
+               goto error;
+       rc = -ENOENT;
+       string = (char *)nla_data(tb[SMC_PNETID_ETHNAME]);
+       pnetelem->ndev = dev_get_by_name(net, string);
+       if (!pnetelem->ndev)
+               goto error;
+
+       rc = -EINVAL;
+       if (!tb[SMC_PNETID_IBNAME])
+               goto error;
+       rc = -ENOENT;
+       ibname = (char *)nla_data(tb[SMC_PNETID_IBNAME]);
+       ibname = strim(ibname);
+       pnetelem->smcibdev = smc_pnet_find_ib(ibname);
+       if (!pnetelem->smcibdev)
+               goto error;
+
+       rc = -EINVAL;
+       if (!tb[SMC_PNETID_IBPORT])
+               goto error;
+       pnetelem->ib_port = nla_get_u8(tb[SMC_PNETID_IBPORT]);
+       if (pnetelem->ib_port < 1 || pnetelem->ib_port > SMC_MAX_PORTS)
+               goto error;
+
        return 0;
 
 error:
@@ -307,6 +312,8 @@ static int smc_pnet_get(struct sk_buff *skb, struct genl_info *info)
        void *hdr;
        int rc;
 
+       if (!info->attrs[SMC_PNETID_NAME])
+               return -EINVAL;
        pnetelem = smc_pnet_find_pnetid(
                                (char *)nla_data(info->attrs[SMC_PNETID_NAME]));
        if (!pnetelem)
@@ -359,6 +366,8 @@ static int smc_pnet_add(struct sk_buff *skb, struct genl_info *info)
 
 static int smc_pnet_del(struct sk_buff *skb, struct genl_info *info)
 {
+       if (!info->attrs[SMC_PNETID_NAME])
+               return -EINVAL;
        return smc_pnet_remove_by_pnetid(
                                (char *)nla_data(info->attrs[SMC_PNETID_NAME]));
 }