]> err.no Git - linux-2.6/blobdiff - net/ipv4/netfilter/ip_nat_core.c
[NETFILTER]: Handle NAT module load race
[linux-2.6] / net / ipv4 / netfilter / ip_nat_core.c
index 9fc6f93af0dd15de8933a14e053bf8cdf41dfb69..1adedb743f609f8b349a545184040d54512f639d 100644 (file)
@@ -22,8 +22,8 @@
 #include <linux/udp.h>
 #include <linux/jhash.h>
 
-#define ASSERT_READ_LOCK(x) MUST_BE_READ_LOCKED(&ip_nat_lock)
-#define ASSERT_WRITE_LOCK(x) MUST_BE_WRITE_LOCKED(&ip_nat_lock)
+#define ASSERT_READ_LOCK(x)
+#define ASSERT_WRITE_LOCK(x)
 
 #include <linux/netfilter_ipv4/ip_conntrack.h>
 #include <linux/netfilter_ipv4/ip_conntrack_core.h>
 #define DEBUGP(format, args...)
 #endif
 
-DECLARE_RWLOCK(ip_nat_lock);
+DEFINE_RWLOCK(ip_nat_lock);
 
 /* Calculated at init based on memory size */
 static unsigned int ip_nat_htable_size;
 
 static struct list_head *bysource;
+
+#define MAX_IP_NAT_PROTO 256
 struct ip_nat_protocol *ip_nat_protos[MAX_IP_NAT_PROTO];
 
+static inline struct ip_nat_protocol *
+__ip_nat_proto_find(u_int8_t protonum)
+{
+       return ip_nat_protos[protonum];
+}
+
+struct ip_nat_protocol *
+ip_nat_proto_find_get(u_int8_t protonum)
+{
+       struct ip_nat_protocol *p;
+
+       /* we need to disable preemption to make sure 'p' doesn't get
+        * removed until we've grabbed the reference */
+       preempt_disable();
+       p = __ip_nat_proto_find(protonum);
+       if (p) {
+               if (!try_module_get(p->me))
+                       p = &ip_nat_unknown_protocol;
+       }
+       preempt_enable();
+
+       return p;
+}
+
+void
+ip_nat_proto_put(struct ip_nat_protocol *p)
+{
+       module_put(p->me);
+}
 
 /* We keep an extra hash for each conntrack, for fast searching. */
 static inline unsigned int
@@ -65,9 +96,9 @@ static void ip_nat_cleanup_conntrack(struct ip_conntrack *conn)
        if (!(conn->status & IPS_NAT_DONE_MASK))
                return;
 
-       WRITE_LOCK(&ip_nat_lock);
+       write_lock_bh(&ip_nat_lock);
        list_del(&conn->nat.info.bysource);
-       WRITE_UNLOCK(&ip_nat_lock);
+       write_unlock_bh(&ip_nat_lock);
 }
 
 /* We do checksum mangling, so if they were wrong before they're still
@@ -103,7 +134,8 @@ static int
 in_range(const struct ip_conntrack_tuple *tuple,
         const struct ip_nat_range *range)
 {
-       struct ip_nat_protocol *proto = ip_nat_find_proto(tuple->dst.protonum);
+       struct ip_nat_protocol *proto = 
+                               __ip_nat_proto_find(tuple->dst.protonum);
 
        /* If we are supposed to map IPs, then we must be in the
           range specified, otherwise let this drag us onto a new src IP. */
@@ -142,7 +174,7 @@ find_appropriate_src(const struct ip_conntrack_tuple *tuple,
        unsigned int h = hash_by_src(tuple);
        struct ip_conntrack *ct;
 
-       READ_LOCK(&ip_nat_lock);
+       read_lock_bh(&ip_nat_lock);
        list_for_each_entry(ct, &bysource[h], nat.info.bysource) {
                if (same_src(ct, tuple)) {
                        /* Copy source part from reply tuple. */
@@ -151,12 +183,12 @@ find_appropriate_src(const struct ip_conntrack_tuple *tuple,
                        result->dst = tuple->dst;
 
                        if (in_range(result, range)) {
-                               READ_UNLOCK(&ip_nat_lock);
+                               read_unlock_bh(&ip_nat_lock);
                                return 1;
                        }
                }
        }
-       READ_UNLOCK(&ip_nat_lock);
+       read_unlock_bh(&ip_nat_lock);
        return 0;
 }
 
@@ -216,8 +248,7 @@ get_unique_tuple(struct ip_conntrack_tuple *tuple,
                 struct ip_conntrack *conntrack,
                 enum ip_nat_manip_type maniptype)
 {
-       struct ip_nat_protocol *proto
-               = ip_nat_find_proto(orig_tuple->dst.protonum);
+       struct ip_nat_protocol *proto;
 
        /* 1) If this srcip/proto/src-proto-part is currently mapped,
           and that same mapping gives a unique tuple within the given
@@ -242,14 +273,20 @@ get_unique_tuple(struct ip_conntrack_tuple *tuple,
        /* 3) The per-protocol part of the manip is made to map into
           the range to make a unique tuple. */
 
+       proto = ip_nat_proto_find_get(orig_tuple->dst.protonum);
+
        /* Only bother mapping if it's not already in range and unique */
        if ((!(range->flags & IP_NAT_RANGE_PROTO_SPECIFIED)
             || proto->in_range(tuple, maniptype, &range->min, &range->max))
-           && !ip_nat_used_tuple(tuple, conntrack))
+           && !ip_nat_used_tuple(tuple, conntrack)) {
+               ip_nat_proto_put(proto);
                return;
+       }
 
        /* Last change: get protocol to try to obtain unique tuple. */
        proto->unique_tuple(tuple, range, maniptype, conntrack);
+
+       ip_nat_proto_put(proto);
 }
 
 unsigned int
@@ -297,9 +334,9 @@ ip_nat_setup_info(struct ip_conntrack *conntrack,
                unsigned int srchash
                        = hash_by_src(&conntrack->tuplehash[IP_CT_DIR_ORIGINAL]
                                      .tuple);
-               WRITE_LOCK(&ip_nat_lock);
+               write_lock_bh(&ip_nat_lock);
                list_add(&info->bysource, &bysource[srchash]);
-               WRITE_UNLOCK(&ip_nat_lock);
+               write_unlock_bh(&ip_nat_lock);
        }
 
        /* It's done. */
@@ -320,17 +357,20 @@ manip_pkt(u_int16_t proto,
          enum ip_nat_manip_type maniptype)
 {
        struct iphdr *iph;
+       struct ip_nat_protocol *p;
 
-       (*pskb)->nfcache |= NFC_ALTERED;
-       if (!skb_ip_make_writable(pskb, iphdroff + sizeof(*iph)))
+       if (!skb_make_writable(pskb, iphdroff + sizeof(*iph)))
                return 0;
 
        iph = (void *)(*pskb)->data + iphdroff;
 
        /* Manipulate protcol part. */
-       if (!ip_nat_find_proto(proto)->manip_pkt(pskb, iphdroff,
-                                                target, maniptype))
+       p = ip_nat_proto_find_get(proto);
+       if (!p->manip_pkt(pskb, iphdroff, target, maniptype)) {
+               ip_nat_proto_put(p);
                return 0;
+       }
+       ip_nat_proto_put(p);
 
        iph = (void *)(*pskb)->data + iphdroff;
 
@@ -391,7 +431,7 @@ int icmp_reply_translation(struct sk_buff **pskb,
        struct ip_conntrack_tuple inner, target;
        int hdrlen = (*pskb)->nh.iph->ihl * 4;
 
-       if (!skb_ip_make_writable(pskb, hdrlen + sizeof(*inside)))
+       if (!skb_make_writable(pskb, hdrlen + sizeof(*inside)))
                return 0;
 
        inside = (void *)(*pskb)->data + (*pskb)->nh.iph->ihl*4;
@@ -426,7 +466,8 @@ int icmp_reply_translation(struct sk_buff **pskb,
 
        if (!ip_ct_get_tuple(&inside->ip, *pskb, (*pskb)->nh.iph->ihl*4 +
                             sizeof(struct icmphdr) + inside->ip.ihl*4,
-                            &inner, ip_ct_find_proto(inside->ip.protocol)))
+                            &inner,
+                            __ip_conntrack_proto_find(inside->ip.protocol)))
                return 0;
 
        /* Change inner back to look like incoming packet.  We do the
@@ -474,28 +515,71 @@ int ip_nat_protocol_register(struct ip_nat_protocol *proto)
 {
        int ret = 0;
 
-       WRITE_LOCK(&ip_nat_lock);
+       write_lock_bh(&ip_nat_lock);
        if (ip_nat_protos[proto->protonum] != &ip_nat_unknown_protocol) {
                ret = -EBUSY;
                goto out;
        }
        ip_nat_protos[proto->protonum] = proto;
  out:
-       WRITE_UNLOCK(&ip_nat_lock);
+       write_unlock_bh(&ip_nat_lock);
        return ret;
 }
 
 /* Noone stores the protocol anywhere; simply delete it. */
 void ip_nat_protocol_unregister(struct ip_nat_protocol *proto)
 {
-       WRITE_LOCK(&ip_nat_lock);
+       write_lock_bh(&ip_nat_lock);
        ip_nat_protos[proto->protonum] = &ip_nat_unknown_protocol;
-       WRITE_UNLOCK(&ip_nat_lock);
+       write_unlock_bh(&ip_nat_lock);
 
        /* Someone could be still looking at the proto in a bh. */
        synchronize_net();
 }
 
+#if defined(CONFIG_IP_NF_CONNTRACK_NETLINK) || \
+    defined(CONFIG_IP_NF_CONNTRACK_NETLINK_MODULE)
+int
+ip_nat_port_range_to_nfattr(struct sk_buff *skb, 
+                           const struct ip_nat_range *range)
+{
+       NFA_PUT(skb, CTA_PROTONAT_PORT_MIN, sizeof(u_int16_t),
+               &range->min.tcp.port);
+       NFA_PUT(skb, CTA_PROTONAT_PORT_MAX, sizeof(u_int16_t),
+               &range->max.tcp.port);
+
+       return 0;
+
+nfattr_failure:
+       return -1;
+}
+
+int
+ip_nat_port_nfattr_to_range(struct nfattr *tb[], struct ip_nat_range *range)
+{
+       int ret = 0;
+       
+       /* we have to return whether we actually parsed something or not */
+
+       if (tb[CTA_PROTONAT_PORT_MIN-1]) {
+               ret = 1;
+               range->min.tcp.port = 
+                       *(u_int16_t *)NFA_DATA(tb[CTA_PROTONAT_PORT_MIN-1]);
+       }
+       
+       if (!tb[CTA_PROTONAT_PORT_MAX-1]) {
+               if (ret) 
+                       range->max.tcp.port = range->min.tcp.port;
+       } else {
+               ret = 1;
+               range->max.tcp.port = 
+                       *(u_int16_t *)NFA_DATA(tb[CTA_PROTONAT_PORT_MAX-1]);
+       }
+
+       return ret;
+}
+#endif
+
 int __init ip_nat_init(void)
 {
        size_t i;
@@ -509,13 +593,13 @@ int __init ip_nat_init(void)
                return -ENOMEM;
 
        /* Sew in builtin protocols. */
-       WRITE_LOCK(&ip_nat_lock);
+       write_lock_bh(&ip_nat_lock);
        for (i = 0; i < MAX_IP_NAT_PROTO; i++)
                ip_nat_protos[i] = &ip_nat_unknown_protocol;
        ip_nat_protos[IPPROTO_TCP] = &ip_nat_protocol_tcp;
        ip_nat_protos[IPPROTO_UDP] = &ip_nat_protocol_udp;
        ip_nat_protos[IPPROTO_ICMP] = &ip_nat_protocol_icmp;
-       WRITE_UNLOCK(&ip_nat_lock);
+       write_unlock_bh(&ip_nat_lock);
 
        for (i = 0; i < ip_nat_htable_size; i++) {
                INIT_LIST_HEAD(&bysource[i]);