summaryrefslogtreecommitdiff
path: root/net/netlink/af_netlink.c
diff options
context:
space:
mode:
Diffstat (limited to 'net/netlink/af_netlink.c')
-rw-r--r--net/netlink/af_netlink.c54
1 files changed, 40 insertions, 14 deletions
diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c
index faa48f70b7c9..9017e3ef8fee 100644
--- a/net/netlink/af_netlink.c
+++ b/net/netlink/af_netlink.c
@@ -137,6 +137,8 @@ static void netlink_destroy_callback(struct netlink_callback *cb);
static DEFINE_RWLOCK(nl_table_lock);
static atomic_t nl_table_users = ATOMIC_INIT(0);
+#define nl_deref_protected(X) rcu_dereference_protected(X, lockdep_is_held(&nl_table_lock));
+
static ATOMIC_NOTIFIER_HEAD(netlink_chain);
static inline u32 netlink_group_mask(u32 group)
@@ -156,6 +158,8 @@ static void netlink_sock_destruct(struct sock *sk)
if (nlk->cb) {
if (nlk->cb->done)
nlk->cb->done(nlk->cb);
+
+ module_put(nlk->cb->module);
netlink_destroy_callback(nlk->cb);
}
@@ -330,6 +334,11 @@ netlink_update_listeners(struct sock *sk)
struct hlist_node *node;
unsigned long mask;
unsigned int i;
+ struct listeners *listeners;
+
+ listeners = nl_deref_protected(tbl->listeners);
+ if (!listeners)
+ return;
for (i = 0; i < NLGRPLONGS(tbl->groups); i++) {
mask = 0;
@@ -337,7 +346,7 @@ netlink_update_listeners(struct sock *sk)
if (i < NLGRPLONGS(nlk_sk(sk)->ngroups))
mask |= nlk_sk(sk)->groups[i];
}
- tbl->listeners->masks[i] = mask;
+ listeners->masks[i] = mask;
}
/* this function is only called with the netlink table "grabbed", which
* makes sure updates are visible before bind or setsockopt return. */
@@ -518,7 +527,11 @@ static int netlink_release(struct socket *sock)
if (netlink_is_kernel(sk)) {
BUG_ON(nl_table[sk->sk_protocol].registered == 0);
if (--nl_table[sk->sk_protocol].registered == 0) {
- kfree(nl_table[sk->sk_protocol].listeners);
+ struct listeners *old;
+
+ old = nl_deref_protected(nl_table[sk->sk_protocol].listeners);
+ RCU_INIT_POINTER(nl_table[sk->sk_protocol].listeners, NULL);
+ kfree_rcu(old, rcu);
nl_table[sk->sk_protocol].module = NULL;
nl_table[sk->sk_protocol].registered = 0;
}
@@ -948,7 +961,7 @@ int netlink_has_listeners(struct sock *sk, unsigned int group)
rcu_read_lock();
listeners = rcu_dereference(nl_table[sk->sk_protocol].listeners);
- if (group - 1 < nl_table[sk->sk_protocol].groups)
+ if (listeners && group - 1 < nl_table[sk->sk_protocol].groups)
res = test_bit(group - 1, listeners->masks);
rcu_read_unlock();
@@ -1329,7 +1342,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
if (NULL == siocb->scm)
siocb->scm = &scm;
- err = scm_send(sock, msg, siocb->scm);
+ err = scm_send(sock, msg, siocb->scm, true);
if (err < 0)
return err;
@@ -1340,7 +1353,8 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock,
dst_pid = addr->nl_pid;
dst_group = ffs(addr->nl_groups);
err = -EPERM;
- if (dst_group && !netlink_capable(sock, NL_NONROOT_SEND))
+ if ((dst_group || dst_pid) &&
+ !netlink_capable(sock, NL_NONROOT_SEND))
goto out;
} else {
dst_pid = nlk->dst_pid;
@@ -1579,7 +1593,7 @@ int __netlink_change_ngroups(struct sock *sk, unsigned int groups)
new = kzalloc(sizeof(*new) + NLGRPSZ(groups), GFP_ATOMIC);
if (!new)
return -ENOMEM;
- old = rcu_dereference_protected(tbl->listeners, 1);
+ old = nl_deref_protected(tbl->listeners);
memcpy(new->masks, old->masks, NLGRPSZ(tbl->groups));
rcu_assign_pointer(tbl->listeners, new);
@@ -1727,6 +1741,7 @@ static int netlink_dump(struct sock *sk)
nlk->cb = NULL;
mutex_unlock(nlk->cb_mutex);
+ module_put(cb->module);
netlink_destroy_callback(cb);
return 0;
@@ -1736,9 +1751,9 @@ errout_skb:
return err;
}
-int netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
- const struct nlmsghdr *nlh,
- struct netlink_dump_control *control)
+int __netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
+ const struct nlmsghdr *nlh,
+ struct netlink_dump_control *control)
{
struct netlink_callback *cb;
struct sock *sk;
@@ -1753,6 +1768,7 @@ int netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
cb->done = control->done;
cb->nlh = nlh;
cb->data = control->data;
+ cb->module = control->module;
cb->min_dump_alloc = control->min_dump_alloc;
atomic_inc(&skb->users);
cb->skb = skb;
@@ -1763,19 +1779,28 @@ int netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
return -ECONNREFUSED;
}
nlk = nlk_sk(sk);
- /* A dump is in progress... */
+
mutex_lock(nlk->cb_mutex);
+ /* A dump is in progress... */
if (nlk->cb) {
mutex_unlock(nlk->cb_mutex);
netlink_destroy_callback(cb);
- sock_put(sk);
- return -EBUSY;
+ ret = -EBUSY;
+ goto out;
}
+ /* add reference of module which cb->dump belongs to */
+ if (!try_module_get(cb->module)) {
+ mutex_unlock(nlk->cb_mutex);
+ netlink_destroy_callback(cb);
+ ret = -EPROTONOSUPPORT;
+ goto out;
+ }
+
nlk->cb = cb;
mutex_unlock(nlk->cb_mutex);
ret = netlink_dump(sk);
-
+out:
sock_put(sk);
if (ret)
@@ -1786,7 +1811,7 @@ int netlink_dump_start(struct sock *ssk, struct sk_buff *skb,
*/
return -EINTR;
}
-EXPORT_SYMBOL(netlink_dump_start);
+EXPORT_SYMBOL(__netlink_dump_start);
void netlink_ack(struct sk_buff *in_skb, struct nlmsghdr *nlh, int err)
{
@@ -2115,6 +2140,7 @@ static void __init netlink_add_usersock_entry(void)
rcu_assign_pointer(nl_table[NETLINK_USERSOCK].listeners, listeners);
nl_table[NETLINK_USERSOCK].module = THIS_MODULE;
nl_table[NETLINK_USERSOCK].registered = 1;
+ nl_table[NETLINK_USERSOCK].nl_nonroot = NL_NONROOT_SEND;
netlink_table_ungrab();
}