X-Git-Url: https://err.no/cgi-bin/gitweb.cgi?a=blobdiff_plain;f=net%2Fnetlink%2Faf_netlink.c;h=62435ffc61846396e6a862dd9907319a1a3b8688;hb=5170dbebbb2e9159cdf6bbf35e5d79cd7009799a;hp=7b7b45a195979dc46efee171bc7a322275023ecc;hpb=db080529798b497eb5a37b92a25e966be5a7dd5d;p=linux-2.6 diff --git a/net/netlink/af_netlink.c b/net/netlink/af_netlink.c index 7b7b45a195..62435ffc61 100644 --- a/net/netlink/af_netlink.c +++ b/net/netlink/af_netlink.c @@ -60,21 +60,29 @@ #include #define Nprintk(a...) +#define NLGRPSZ(x) (ALIGN(x, sizeof(unsigned long) * 8) / 8) struct netlink_sock { /* struct sock has to be the first member of netlink_sock */ struct sock sk; u32 pid; - unsigned int groups; u32 dst_pid; - unsigned int dst_groups; + u32 dst_group; + u32 flags; + u32 subscriptions; + u32 ngroups; + unsigned long *groups; unsigned long state; wait_queue_head_t wait; struct netlink_callback *cb; spinlock_t cb_lock; void (*data_ready)(struct sock *sk, int bytes); + struct module *module; }; +#define NETLINK_KERNEL_SOCKET 0x1 +#define NETLINK_RECV_PKTINFO 0x2 + static inline struct netlink_sock *nlk_sk(struct sock *sk) { return (struct netlink_sock *)sk; @@ -97,7 +105,9 @@ struct netlink_table { struct nl_pid_hash hash; struct hlist_head mc_list; unsigned int nl_nonroot; - struct proto_ops *p_ops; + unsigned int groups; + struct module *module; + int registered; }; static struct netlink_table *nl_table; @@ -112,6 +122,11 @@ static atomic_t nl_table_users = ATOMIC_INIT(0); static struct notifier_block *netlink_chain; +static u32 netlink_group_mask(u32 group) +{ + return group ? 1 << (group - 1) : 0; +} + static struct hlist_head *nl_pid_hashfn(struct nl_pid_hash *hash, u32 pid) { return &hash->table[jhash_1word(pid, hash->rnd) & hash->mask]; @@ -128,6 +143,7 @@ static void netlink_sock_destruct(struct sock *sk) BUG_TRAP(!atomic_read(&sk->sk_rmem_alloc)); BUG_TRAP(!atomic_read(&sk->sk_wmem_alloc)); BUG_TRAP(!nlk_sk(sk)->cb); + BUG_TRAP(!nlk_sk(sk)->groups); } /* This lock without WQ_FLAG_EXCLUSIVE is good on UP and it is _very_ bad on SMP. @@ -323,7 +339,7 @@ static void netlink_remove(struct sock *sk) netlink_table_grab(); if (sk_del_node_init(sk)) nl_table[sk->sk_protocol].hash.entries--; - if (nlk_sk(sk)->groups) + if (nlk_sk(sk)->subscriptions) __sk_del_bind_node(sk); netlink_table_ungrab(); } @@ -334,11 +350,35 @@ static struct proto netlink_proto = { .obj_size = sizeof(struct netlink_sock), }; -static int netlink_create(struct socket *sock, int protocol) +static int __netlink_create(struct socket *sock, int protocol) { struct sock *sk; struct netlink_sock *nlk; + sock->ops = &netlink_ops; + + sk = sk_alloc(PF_NETLINK, GFP_KERNEL, &netlink_proto, 1); + if (!sk) + return -ENOMEM; + + sock_init_data(sock, sk); + + nlk = nlk_sk(sk); + spin_lock_init(&nlk->cb_lock); + init_waitqueue_head(&nlk->wait); + + sk->sk_destruct = netlink_sock_destruct; + sk->sk_protocol = protocol; + return 0; +} + +static int netlink_create(struct socket *sock, int protocol) +{ + struct module *module = NULL; + struct netlink_sock *nlk; + unsigned int groups; + int err = 0; + sock->state = SS_UNCONNECTED; if (sock->type != SOCK_RAW && sock->type != SOCK_DGRAM) @@ -347,36 +387,42 @@ static int netlink_create(struct socket *sock, int protocol) if (protocol<0 || protocol >= MAX_LINKS) return -EPROTONOSUPPORT; - netlink_table_grab(); - if (!nl_table[protocol].hash.entries) { + netlink_lock_table(); #ifdef CONFIG_KMOD - /* We do 'best effort'. If we find a matching module, - * it is loaded. If not, we don't return an error to - * allow pure userspace<->userspace communication. -HW - */ - netlink_table_ungrab(); + if (!nl_table[protocol].registered) { + netlink_unlock_table(); request_module("net-pf-%d-proto-%d", PF_NETLINK, protocol); - netlink_table_grab(); -#endif + netlink_lock_table(); } - netlink_table_ungrab(); +#endif + if (nl_table[protocol].registered && + try_module_get(nl_table[protocol].module)) + module = nl_table[protocol].module; + else + err = -EPROTONOSUPPORT; + groups = nl_table[protocol].groups; + netlink_unlock_table(); - sock->ops = nl_table[protocol].p_ops; + if (err || (err = __netlink_create(sock, protocol) < 0)) + goto out_module; - sk = sk_alloc(PF_NETLINK, GFP_KERNEL, &netlink_proto, 1); - if (!sk) - return -ENOMEM; + nlk = nlk_sk(sock->sk); - sock_init_data(sock, sk); - - nlk = nlk_sk(sk); + nlk->groups = kmalloc(NLGRPSZ(groups), GFP_KERNEL); + if (nlk->groups == NULL) { + err = -ENOMEM; + goto out_module; + } + memset(nlk->groups, 0, NLGRPSZ(groups)); + nlk->ngroups = groups; - spin_lock_init(&nlk->cb_lock); - init_waitqueue_head(&nlk->wait); - sk->sk_destruct = netlink_sock_destruct; + nlk->module = module; +out: + return err; - sk->sk_protocol = protocol; - return 0; +out_module: + module_put(module); + goto out; } static int netlink_release(struct socket *sock) @@ -407,7 +453,7 @@ static int netlink_release(struct socket *sock) skb_queue_purge(&sk->sk_write_queue); - if (nlk->pid && !nlk->groups) { + if (nlk->pid && !nlk->subscriptions) { struct netlink_notify n = { .protocol = sk->sk_protocol, .pid = nlk->pid, @@ -415,22 +461,19 @@ static int netlink_release(struct socket *sock) notifier_call_chain(&netlink_chain, NETLINK_URELEASE, &n); } - /* When this is a kernel socket, we need to remove the owner pointer, - * since we don't know whether the module will be dying at any given - * point - HW - */ - if (!nlk->pid) { - struct proto_ops *p_tmp; + if (nlk->module) + module_put(nlk->module); + if (nlk->flags & NETLINK_KERNEL_SOCKET) { netlink_table_grab(); - p_tmp = nl_table[sk->sk_protocol].p_ops; - if (p_tmp != &netlink_ops) { - nl_table[sk->sk_protocol].p_ops = &netlink_ops; - kfree(p_tmp); - } + nl_table[sk->sk_protocol].module = NULL; + nl_table[sk->sk_protocol].registered = 0; netlink_table_ungrab(); } - + + kfree(nlk->groups); + nlk->groups = NULL; + sock_put(sk); return 0; } @@ -479,6 +522,18 @@ static inline int netlink_capable(struct socket *sock, unsigned int flag) capable(CAP_NET_ADMIN); } +static void +netlink_update_subscriptions(struct sock *sk, unsigned int subscriptions) +{ + struct netlink_sock *nlk = nlk_sk(sk); + + if (nlk->subscriptions && !subscriptions) + __sk_del_bind_node(sk); + else if (!nlk->subscriptions && subscriptions) + sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list); + nlk->subscriptions = subscriptions; +} + static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len) { struct sock *sk = sock->sk; @@ -504,15 +559,14 @@ static int netlink_bind(struct socket *sock, struct sockaddr *addr, int addr_len return err; } - if (!nladdr->nl_groups && !nlk->groups) + if (!nladdr->nl_groups && !(u32)nlk->groups[0]) return 0; netlink_table_grab(); - if (nlk->groups && !nladdr->nl_groups) - __sk_del_bind_node(sk); - else if (!nlk->groups && nladdr->nl_groups) - sk_add_bind_node(sk, &nl_table[sk->sk_protocol].mc_list); - nlk->groups = nladdr->nl_groups; + netlink_update_subscriptions(sk, nlk->subscriptions + + hweight32(nladdr->nl_groups) - + hweight32(nlk->groups[0])); + nlk->groups[0] = (nlk->groups[0] & ~0xffffffffUL) | nladdr->nl_groups; netlink_table_ungrab(); return 0; @@ -529,7 +583,7 @@ static int netlink_connect(struct socket *sock, struct sockaddr *addr, if (addr->sa_family == AF_UNSPEC) { sk->sk_state = NETLINK_UNCONNECTED; nlk->dst_pid = 0; - nlk->dst_groups = 0; + nlk->dst_group = 0; return 0; } if (addr->sa_family != AF_NETLINK) @@ -545,7 +599,7 @@ static int netlink_connect(struct socket *sock, struct sockaddr *addr, if (err == 0) { sk->sk_state = NETLINK_CONNECTED; nlk->dst_pid = nladdr->nl_pid; - nlk->dst_groups = nladdr->nl_groups; + nlk->dst_group = ffs(nladdr->nl_groups); } return err; @@ -563,10 +617,10 @@ static int netlink_getname(struct socket *sock, struct sockaddr *addr, int *addr if (peer) { nladdr->nl_pid = nlk->dst_pid; - nladdr->nl_groups = nlk->dst_groups; + nladdr->nl_groups = netlink_group_mask(nlk->dst_group); } else { nladdr->nl_pid = nlk->pid; - nladdr->nl_groups = nlk->groups; + nladdr->nl_groups = nlk->groups[0]; } return 0; } @@ -767,7 +821,8 @@ static inline int do_one_broadcast(struct sock *sk, if (p->exclude_sk == sk) goto out; - if (nlk->pid == p->pid || !(nlk->groups & p->group)) + if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups || + !test_bit(p->group - 1, nlk->groups)) goto out; if (p->failure) { @@ -806,7 +861,7 @@ out: } int netlink_broadcast(struct sock *ssk, struct sk_buff *skb, u32 pid, - u32 group, int allocation) + u32 group, unsigned int __nocast allocation) { struct netlink_broadcast_data info; struct hlist_node *node; @@ -863,7 +918,8 @@ static inline int do_one_set_err(struct sock *sk, if (sk == p->exclude_sk) goto out; - if (nlk->pid == p->pid || !(nlk->groups & p->group)) + if (nlk->pid == p->pid || p->group - 1 >= nlk->ngroups || + !test_bit(p->group - 1, nlk->groups)) goto out; sk->sk_err = p->code; @@ -891,6 +947,94 @@ void netlink_set_err(struct sock *ssk, u32 pid, u32 group, int code) read_unlock(&nl_table_lock); } +static int netlink_setsockopt(struct socket *sock, int level, int optname, + char __user *optval, int optlen) +{ + struct sock *sk = sock->sk; + struct netlink_sock *nlk = nlk_sk(sk); + int val = 0, err; + + if (level != SOL_NETLINK) + return -ENOPROTOOPT; + + if (optlen >= sizeof(int) && + get_user(val, (int __user *)optval)) + return -EFAULT; + + switch (optname) { + case NETLINK_PKTINFO: + if (val) + nlk->flags |= NETLINK_RECV_PKTINFO; + else + nlk->flags &= ~NETLINK_RECV_PKTINFO; + err = 0; + break; + case NETLINK_ADD_MEMBERSHIP: + case NETLINK_DROP_MEMBERSHIP: { + unsigned int subscriptions; + int old, new = optname == NETLINK_ADD_MEMBERSHIP ? 1 : 0; + + if (!netlink_capable(sock, NL_NONROOT_RECV)) + return -EPERM; + if (!val || val - 1 >= nlk->ngroups) + return -EINVAL; + netlink_table_grab(); + old = test_bit(val - 1, nlk->groups); + subscriptions = nlk->subscriptions - old + new; + if (new) + __set_bit(val - 1, nlk->groups); + else + __clear_bit(val - 1, nlk->groups); + netlink_update_subscriptions(sk, subscriptions); + netlink_table_ungrab(); + err = 0; + break; + } + default: + err = -ENOPROTOOPT; + } + return err; +} + +static int netlink_getsockopt(struct socket *sock, int level, int optname, + char __user *optval, int __user *optlen) +{ + struct sock *sk = sock->sk; + struct netlink_sock *nlk = nlk_sk(sk); + int len, val, err; + + if (level != SOL_NETLINK) + return -ENOPROTOOPT; + + if (get_user(len, optlen)) + return -EFAULT; + if (len < 0) + return -EINVAL; + + switch (optname) { + case NETLINK_PKTINFO: + if (len < sizeof(int)) + return -EINVAL; + len = sizeof(int); + val = nlk->flags & NETLINK_RECV_PKTINFO ? 1 : 0; + put_user(len, optlen); + put_user(val, optval); + err = 0; + break; + default: + err = -ENOPROTOOPT; + } + return err; +} + +static void netlink_cmsg_recv_pktinfo(struct msghdr *msg, struct sk_buff *skb) +{ + struct nl_pktinfo info; + + info.group = NETLINK_CB(skb).dst_group; + put_cmsg(msg, SOL_NETLINK, NETLINK_PKTINFO, sizeof(info), &info); +} + static inline void netlink_rcv_wake(struct sock *sk) { struct netlink_sock *nlk = nlk_sk(sk); @@ -909,7 +1053,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, struct netlink_sock *nlk = nlk_sk(sk); struct sockaddr_nl *addr=msg->msg_name; u32 dst_pid; - u32 dst_groups; + u32 dst_group; struct sk_buff *skb; int err; struct scm_cookie scm; @@ -927,12 +1071,12 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, if (addr->nl_family != AF_NETLINK) return -EINVAL; dst_pid = addr->nl_pid; - dst_groups = addr->nl_groups; - if (dst_groups && !netlink_capable(sock, NL_NONROOT_SEND)) + dst_group = ffs(addr->nl_groups); + if (dst_group && !netlink_capable(sock, NL_NONROOT_SEND)) return -EPERM; } else { dst_pid = nlk->dst_pid; - dst_groups = nlk->dst_groups; + dst_group = nlk->dst_group; } if (!nlk->pid) { @@ -951,7 +1095,7 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, NETLINK_CB(skb).pid = nlk->pid; NETLINK_CB(skb).dst_pid = dst_pid; - NETLINK_CB(skb).dst_groups = dst_groups; + NETLINK_CB(skb).dst_group = dst_group; NETLINK_CB(skb).loginuid = audit_get_loginuid(current->audit_context); memcpy(NETLINK_CREDS(skb), &siocb->scm->creds, sizeof(struct ucred)); @@ -973,9 +1117,9 @@ static int netlink_sendmsg(struct kiocb *kiocb, struct socket *sock, goto out; } - if (dst_groups) { + if (dst_group) { atomic_inc(&skb->users); - netlink_broadcast(sk, skb, dst_pid, dst_groups, GFP_KERNEL); + netlink_broadcast(sk, skb, dst_pid, dst_group, GFP_KERNEL); } err = netlink_unicast(sk, skb, dst_pid, msg->msg_flags&MSG_DONTWAIT); @@ -1021,7 +1165,7 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock, addr->nl_family = AF_NETLINK; addr->nl_pad = 0; addr->nl_pid = NETLINK_CB(skb).pid; - addr->nl_groups = NETLINK_CB(skb).dst_groups; + addr->nl_groups = netlink_group_mask(NETLINK_CB(skb).dst_group); msg->msg_namelen = sizeof(*addr); } @@ -1036,6 +1180,8 @@ static int netlink_recvmsg(struct kiocb *kiocb, struct socket *sock, netlink_dump(sk); scm_recv(sock, msg, siocb->scm, flags); + if (nlk->flags & NETLINK_RECV_PKTINFO) + netlink_cmsg_recv_pktinfo(msg, skb); out: netlink_rcv_wake(sk); @@ -1058,11 +1204,13 @@ static void netlink_data_ready(struct sock *sk, int len) */ struct sock * -netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct module *module) +netlink_kernel_create(int unit, unsigned int groups, + void (*input)(struct sock *sk, int len), + struct module *module) { - struct proto_ops *p_ops; struct socket *sock; struct sock *sk; + struct netlink_sock *nlk; if (!nl_table) return NULL; @@ -1070,64 +1218,34 @@ netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct if (unit<0 || unit>=MAX_LINKS) return NULL; - /* Do a quick check, to make us not go down to netlink_insert() - * if protocol already has kernel socket. - */ - sk = netlink_lookup(unit, 0); - if (unlikely(sk)) { - sock_put(sk); - return NULL; - } - if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock)) return NULL; - sk = NULL; - if (module) { - /* Every registering protocol implemented in a module needs - * it's own p_ops, since the socket code cannot deal with - * module refcounting otherwise. -HW - */ - p_ops = kmalloc(sizeof(*p_ops), GFP_KERNEL); - if (!p_ops) - goto out_sock_release; - - memcpy(p_ops, &netlink_ops, sizeof(*p_ops)); - p_ops->owner = module; - } else - p_ops = &netlink_ops; - - netlink_table_grab(); - nl_table[unit].p_ops = p_ops; - netlink_table_ungrab(); - - if (netlink_create(sock, unit) < 0) { - sk = NULL; - goto out_kfree_p_ops; - } + if (__netlink_create(sock, unit) < 0) + goto out_sock_release; sk = sock->sk; sk->sk_data_ready = netlink_data_ready; if (input) nlk_sk(sk)->data_ready = input; - if (netlink_insert(sk, 0)) { - sk = NULL; - goto out_kfree_p_ops; - } + if (netlink_insert(sk, 0)) + goto out_sock_release; - return sk; + nlk = nlk_sk(sk); + nlk->flags |= NETLINK_KERNEL_SOCKET; -out_kfree_p_ops: netlink_table_grab(); - if (nl_table[unit].p_ops != &netlink_ops) { - kfree(nl_table[unit].p_ops); - nl_table[unit].p_ops = &netlink_ops; - } + nl_table[unit].groups = groups < 32 ? 32 : groups; + nl_table[unit].module = module; + nl_table[unit].registered = 1; netlink_table_ungrab(); + + return sk; + out_sock_release: sock_release(sock); - return sk; + return NULL; } void netlink_set_nonroot(int protocol, unsigned int flags) @@ -1365,7 +1483,8 @@ static int netlink_seq_show(struct seq_file *seq, void *v) s, s->sk_protocol, nlk->pid, - nlk->groups, + nlk->flags & NETLINK_KERNEL_SOCKET ? + 0 : (unsigned int)nlk->groups[0], atomic_read(&s->sk_rmem_alloc), atomic_read(&s->sk_wmem_alloc), nlk->cb, @@ -1439,8 +1558,8 @@ static struct proto_ops netlink_ops = { .ioctl = sock_no_ioctl, .listen = sock_no_listen, .shutdown = sock_no_shutdown, - .setsockopt = sock_no_setsockopt, - .getsockopt = sock_no_getsockopt, + .setsockopt = netlink_setsockopt, + .getsockopt = netlink_getsockopt, .sendmsg = netlink_sendmsg, .recvmsg = netlink_recvmsg, .mmap = sock_no_mmap, @@ -1490,8 +1609,6 @@ enomem: for (i = 0; i < MAX_LINKS; i++) { struct nl_pid_hash *hash = &nl_table[i].hash; - nl_table[i].p_ops = &netlink_ops; - hash->table = nl_pid_hash_alloc(1 * sizeof(*hash->table)); if (!hash->table) { while (i-- > 0)