[IB] ib_umad: fix crash when freeing send buffers
/*
 * Copyright (c) 2004 Topspin Communications.  All rights reserved.
 * Copyright (c) 2005 Voltaire, Inc. All rights reserved.
 * Copyright (c) 2005 Sun Microsystems, Inc. All rights reserved.
 *
 * This software is available to you under a choice of one of two
 * licenses.  You may choose to be licensed under the terms of the GNU
 * General Public License (GPL) Version 2, available from the file
 * COPYING in the main directory of this source tree, or the
 * OpenIB.org BSD license below:
 *
 *     Redistribution and use in source and binary forms, with or
 *     without modification, are permitted provided that the following
 *     conditions are met:
 *
 *      - Redistributions of source code must retain the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer.
 *
 *      - Redistributions in binary form must reproduce the above
 *        copyright notice, this list of conditions and the following
 *        disclaimer in the documentation and/or other materials
 *        provided with the distribution.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
 * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
 * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
 * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
 * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
 * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
 * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 *
 * $Id: user_mad.c 2814 2005-07-06 19:14:09Z halr $
 */

#include <linux/module.h>
#include <linux/init.h>
#include <linux/device.h>
#include <linux/err.h>
#include <linux/fs.h>
#include <linux/cdev.h>
#include <linux/pci.h>
#include <linux/dma-mapping.h>
#include <linux/poll.h>
#include <linux/rwsem.h>
#include <linux/kref.h>

#include <asm/uaccess.h>
#include <asm/semaphore.h>

#include <rdma/ib_mad.h>
#include <rdma/ib_user_mad.h>

MODULE_AUTHOR("Roland Dreier");
MODULE_DESCRIPTION("InfiniBand userspace MAD packet access");
MODULE_LICENSE("Dual BSD/GPL");

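/*
 * Each IB port is exposed through two character devices: umadN, which
 * lets a userspace process register MAD agents and send and receive
 * MADs, and issmN, which claims the subnet manager role for the port
 * (the IB_PORT_SM capability bit stays set for as long as the issm
 * file is held open).
 */
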
enum {
        IB_UMAD_MAX_PORTS  = 64,
        IB_UMAD_MAX_AGENTS = 32,

        IB_UMAD_MAJOR      = 231,
        IB_UMAD_MINOR_BASE = 0
};

struct ib_umad_port {
        int                    devnum;
        struct cdev            dev;
        struct class_device    class_dev;

        int                    sm_devnum;
        struct cdev            sm_dev;
        struct class_device    sm_class_dev;
        struct semaphore       sm_sem;

        struct ib_device      *ib_dev;
        struct ib_umad_device *umad_dev;
        u8                     port_num;
};

struct ib_umad_device {
        int                  start_port, end_port;
        struct kref          ref;
        struct ib_umad_port  port[0];
};

struct ib_umad_file {
        struct ib_umad_port *port;
        spinlock_t           recv_lock;
        struct list_head     recv_list;
        wait_queue_head_t    recv_wait;
        struct rw_semaphore  agent_mutex;
        struct ib_mad_agent *agent[IB_UMAD_MAX_AGENTS];
        struct ib_mr        *mr[IB_UMAD_MAX_AGENTS];
};

struct ib_umad_packet {
        struct ib_mad_send_buf *msg;
        struct list_head   list;
        int                length;
        DECLARE_PCI_UNMAP_ADDR(mapping)
        struct ib_user_mad mad;
};

static const dev_t base_dev = MKDEV(IB_UMAD_MAJOR, IB_UMAD_MINOR_BASE);
static spinlock_t map_lock;
static DECLARE_BITMAP(dev_map, IB_UMAD_MAX_PORTS * 2);

static void ib_umad_add_one(struct ib_device *device);
static void ib_umad_remove_one(struct ib_device *device);

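/*
 * Queue a received packet on the file that owns the given agent.  The
 * agent's slot index doubles as the user-visible id in the packet
 * header.  Returns nonzero if the agent is no longer registered with
 * the file, in which case the caller still owns the packet.
 */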
static int queue_packet(struct ib_umad_file *file,
                        struct ib_mad_agent *agent,
                        struct ib_umad_packet *packet)
{
        int ret = 1;

        down_read(&file->agent_mutex);
        for (packet->mad.hdr.id = 0;
             packet->mad.hdr.id < IB_UMAD_MAX_AGENTS;
             packet->mad.hdr.id++)
                if (agent == file->agent[packet->mad.hdr.id]) {
                        spin_lock_irq(&file->recv_lock);
                        list_add_tail(&packet->list, &file->recv_list);
                        spin_unlock_irq(&file->recv_lock);
                        wake_up_interruptible(&file->recv_wait);
                        ret = 0;
                        break;
                }

        up_read(&file->agent_mutex);

        return ret;
}

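/*
 * Send completion handler.  The send buffer and its address handle are
 * always freed here; if the send timed out waiting for a response, a
 * small header-only packet with status ETIMEDOUT is queued so that
 * userspace can match the failure back to its request.
 */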
static void send_handler(struct ib_mad_agent *agent,
                         struct ib_mad_send_wc *send_wc)
{
        struct ib_umad_file *file = agent->context;
        struct ib_umad_packet *timeout;
        struct ib_umad_packet *packet = send_wc->send_buf->context[0];

        ib_destroy_ah(packet->msg->ah);
        ib_free_send_mad(packet->msg);

        if (send_wc->status == IB_WC_RESP_TIMEOUT_ERR) {
                timeout = kmalloc(sizeof *timeout + sizeof (struct ib_mad_hdr),
                                  GFP_KERNEL);
                if (!timeout)
                        goto out;

                memset(timeout, 0, sizeof *timeout + sizeof (struct ib_mad_hdr));

                timeout->length = sizeof (struct ib_mad_hdr);
                timeout->mad.hdr.id = packet->mad.hdr.id;
                timeout->mad.hdr.status = ETIMEDOUT;
                memcpy(timeout->mad.data, packet->mad.data,
                       sizeof (struct ib_mad_hdr));

                if (!queue_packet(file, agent, timeout))
                        return;
        }
out:
        kfree(packet);
}

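/*
 * Receive completion handler: coalesce the (possibly multi-segment
 * RMPP) MAD into a freshly allocated packet, record the wire metadata
 * from the work completion, and queue the packet for ib_umad_read().
 */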
static void recv_handler(struct ib_mad_agent *agent,
                         struct ib_mad_recv_wc *mad_recv_wc)
{
        struct ib_umad_file *file = agent->context;
        struct ib_umad_packet *packet;
        int length;

        if (mad_recv_wc->wc->status != IB_WC_SUCCESS)
                goto out;

        length = mad_recv_wc->mad_len;
        packet = kmalloc(sizeof *packet + length, GFP_KERNEL);
        if (!packet)
                goto out;

        memset(packet, 0, sizeof *packet + length);
        packet->length = length;

        ib_coalesce_recv_mad(mad_recv_wc, packet->mad.data);

        packet->mad.hdr.status    = 0;
        packet->mad.hdr.length    = length + sizeof (struct ib_user_mad);
        packet->mad.hdr.qpn       = cpu_to_be32(mad_recv_wc->wc->src_qp);
        packet->mad.hdr.lid       = cpu_to_be16(mad_recv_wc->wc->slid);
        packet->mad.hdr.sl        = mad_recv_wc->wc->sl;
        packet->mad.hdr.path_bits = mad_recv_wc->wc->dlid_path_bits;
        packet->mad.hdr.grh_present = !!(mad_recv_wc->wc->wc_flags & IB_WC_GRH);
        if (packet->mad.hdr.grh_present) {
                /* XXX parse GRH */
                packet->mad.hdr.gid_index       = 0;
                packet->mad.hdr.hop_limit       = 0;
                packet->mad.hdr.traffic_class   = 0;
                memset(packet->mad.hdr.gid, 0, 16);
                packet->mad.hdr.flow_label      = 0;
        }

        if (queue_packet(file, agent, packet))
                kfree(packet);

out:
        ib_free_recv_mad(mad_recv_wc);
}

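/*
 * read() hands back one queued MAD prefixed by its struct ib_user_mad
 * header.  If the buffer is too small for the whole MAD, the header
 * (which carries the full length) and the first segment are copied out
 * and -ENOSPC is returned so userspace can retry with a bigger buffer.
 * On any error the packet is requeued rather than dropped.
 */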
static ssize_t ib_umad_read(struct file *filp, char __user *buf,
                            size_t count, loff_t *pos)
{
        struct ib_umad_file *file = filp->private_data;
        struct ib_umad_packet *packet;
        ssize_t ret;

        if (count < sizeof (struct ib_user_mad) + sizeof (struct ib_mad))
                return -EINVAL;

        spin_lock_irq(&file->recv_lock);

        while (list_empty(&file->recv_list)) {
                spin_unlock_irq(&file->recv_lock);

                if (filp->f_flags & O_NONBLOCK)
                        return -EAGAIN;

                if (wait_event_interruptible(file->recv_wait,
                                             !list_empty(&file->recv_list)))
                        return -ERESTARTSYS;

                spin_lock_irq(&file->recv_lock);
        }

        packet = list_entry(file->recv_list.next, struct ib_umad_packet, list);
        list_del(&packet->list);

        spin_unlock_irq(&file->recv_lock);

        if (count < packet->length + sizeof (struct ib_user_mad)) {
                /* Return length needed (and first RMPP segment) if too small */
                if (copy_to_user(buf, &packet->mad,
                                 sizeof (struct ib_user_mad) + sizeof (struct ib_mad)))
                        ret = -EFAULT;
                else
                        ret = -ENOSPC;
        } else if (copy_to_user(buf, &packet->mad,
                                packet->length + sizeof (struct ib_user_mad)))
                ret = -EFAULT;
        else
                ret = packet->length + sizeof (struct ib_user_mad);
        if (ret < 0) {
                /* Requeue packet */
                spin_lock_irq(&file->recv_lock);
                list_add(&packet->list, &file->recv_list);
                spin_unlock_irq(&file->recv_lock);
        } else
                kfree(packet);
        return ret;
}

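/*
 * write() sends one MAD: build an address handle from the user-space
 * header, allocate a send buffer (leaving the RMPP header in place for
 * RMPP sends), copy the payload in, and post the send.  For request
 * MADs the high 32 bits of the transaction ID are overwritten with the
 * agent's hi_tid so that the matching response is routed back here.
 */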
static ssize_t ib_umad_write(struct file *filp, const char __user *buf,
                             size_t count, loff_t *pos)
{
        struct ib_umad_file *file = filp->private_data;
        struct ib_umad_packet *packet;
        struct ib_mad_agent *agent;
        struct ib_ah_attr ah_attr;
        struct ib_ah *ah;
        struct ib_rmpp_mad *rmpp_mad;
        u8 method;
        __be64 *tid;
        int ret, length, hdr_len, rmpp_hdr_size;
        int rmpp_active = 0;

        if (count < sizeof (struct ib_user_mad))
                return -EINVAL;

        length = count - sizeof (struct ib_user_mad);
        packet = kmalloc(sizeof *packet + sizeof (struct ib_mad_hdr) +
                         sizeof (struct ib_rmpp_hdr), GFP_KERNEL);
        if (!packet)
                return -ENOMEM;

        if (copy_from_user(&packet->mad, buf,
                           sizeof (struct ib_user_mad) +
                           sizeof (struct ib_mad_hdr) +
                           sizeof (struct ib_rmpp_hdr))) {
                ret = -EFAULT;
                goto err;
        }
        /* hdr.id is unsigned, so only the upper bound needs checking */
        if (packet->mad.hdr.id >= IB_UMAD_MAX_AGENTS) {
                ret = -EINVAL;
                goto err;
        }

        packet->length = length;

        down_read(&file->agent_mutex);

        agent = file->agent[packet->mad.hdr.id];
        if (!agent) {
                ret = -EINVAL;
                goto err_up;
        }

        memset(&ah_attr, 0, sizeof ah_attr);
        ah_attr.dlid          = be16_to_cpu(packet->mad.hdr.lid);
        ah_attr.sl            = packet->mad.hdr.sl;
        ah_attr.src_path_bits = packet->mad.hdr.path_bits;
        ah_attr.port_num      = file->port->port_num;
        if (packet->mad.hdr.grh_present) {
                ah_attr.ah_flags = IB_AH_GRH;
                memcpy(ah_attr.grh.dgid.raw, packet->mad.hdr.gid, 16);
                ah_attr.grh.flow_label     = be32_to_cpu(packet->mad.hdr.flow_label);
                ah_attr.grh.hop_limit      = packet->mad.hdr.hop_limit;
                ah_attr.grh.traffic_class  = packet->mad.hdr.traffic_class;
        }

        ah = ib_create_ah(agent->qp->pd, &ah_attr);
        if (IS_ERR(ah)) {
                ret = PTR_ERR(ah);
                goto err_up;
        }

        rmpp_mad = (struct ib_rmpp_mad *) packet->mad.data;
        if (ib_get_rmpp_flags(&rmpp_mad->rmpp_hdr) & IB_MGMT_RMPP_FLAG_ACTIVE) {
                /* RMPP active */
                if (!agent->rmpp_version) {
                        ret = -EINVAL;
                        goto err_ah;
                }

                /* Validate that the management class can support RMPP */
                if (rmpp_mad->mad_hdr.mgmt_class == IB_MGMT_CLASS_SUBN_ADM) {
                        hdr_len = IB_MGMT_SA_HDR;
                } else if ((rmpp_mad->mad_hdr.mgmt_class >= IB_MGMT_CLASS_VENDOR_RANGE2_START) &&
                           (rmpp_mad->mad_hdr.mgmt_class <= IB_MGMT_CLASS_VENDOR_RANGE2_END)) {
                        hdr_len = IB_MGMT_VENDOR_HDR;
                } else {
                        ret = -EINVAL;
                        goto err_ah;
                }
                rmpp_active = 1;
        } else {
                if (length > sizeof (struct ib_mad)) {
                        ret = -EINVAL;
                        goto err_ah;
                }
                hdr_len = IB_MGMT_MAD_HDR;
        }

        packet->msg = ib_create_send_mad(agent,
                                         be32_to_cpu(packet->mad.hdr.qpn),
                                         0, rmpp_active,
                                         hdr_len, length - hdr_len,
                                         GFP_KERNEL);
        if (IS_ERR(packet->msg)) {
                ret = PTR_ERR(packet->msg);
                goto err_ah;
        }

        packet->msg->ah         = ah;
        packet->msg->timeout_ms = packet->mad.hdr.timeout_ms;
        packet->msg->retries    = packet->mad.hdr.retries;
        packet->msg->context[0] = packet;

        if (!rmpp_active) {
                /* Copy message from user into send buffer */
                if (copy_from_user(packet->msg->mad,
                                   buf + sizeof (struct ib_user_mad), length)) {
                        ret = -EFAULT;
                        goto err_msg;
                }
        } else {
                rmpp_hdr_size = sizeof (struct ib_mad_hdr) +
                                sizeof (struct ib_rmpp_hdr);

                /* Only copy MAD headers (RMPP header in place) */
                memcpy(packet->msg->mad, packet->mad.data,
                       sizeof (struct ib_mad_hdr));

                /* Now, copy rest of message from user into send buffer */
                if (copy_from_user(((struct ib_rmpp_mad *) packet->msg->mad)->data,
                                   buf + sizeof (struct ib_user_mad) + rmpp_hdr_size,
                                   length - rmpp_hdr_size)) {
                        ret = -EFAULT;
                        goto err_msg;
                }
        }

        /*
         * If userspace is generating a request that will generate a
         * response, we need to make sure the high-order part of the
         * transaction ID matches the agent being used to send the
         * MAD.
         */
        method = ((struct ib_mad_hdr *) packet->msg->mad)->method;

        if (!(method & IB_MGMT_METHOD_RESP)       &&
            method != IB_MGMT_METHOD_TRAP_REPRESS &&
            method != IB_MGMT_METHOD_SEND) {
                tid = &((struct ib_mad_hdr *) packet->msg->mad)->tid;
                *tid = cpu_to_be64(((u64) agent->hi_tid) << 32 |
                                   (be64_to_cpup(tid) & 0xffffffff));
        }

        ret = ib_post_send_mad(packet->msg, NULL);
        if (ret)
                goto err_msg;

        up_read(&file->agent_mutex);

        return sizeof (struct ib_user_mad_hdr) + packet->length;

err_msg:
        ib_free_send_mad(packet->msg);

err_ah:
        ib_destroy_ah(ah);

err_up:
        up_read(&file->agent_mutex);

err:
        kfree(packet);
        return ret;
}

static unsigned int ib_umad_poll(struct file *filp, struct poll_table_struct *wait)
{
        struct ib_umad_file *file = filp->private_data;

        /* we will always be able to post a MAD send */
        unsigned int mask = POLLOUT | POLLWRNORM;

        poll_wait(filp, &file->recv_wait, wait);

        if (!list_empty(&file->recv_list))
                mask |= POLLIN | POLLRDNORM;

        return mask;
}

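/*
 * IB_USER_MAD_REGISTER_AGENT: register a MAD agent on the port's SMI
 * (QP0) or GSI (QP1), store it in the first free slot (returned to
 * userspace as the agent id), and set up a DMA MR on the agent's PD
 * for the MAD buffers.
 */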
static int ib_umad_reg_agent(struct ib_umad_file *file, unsigned long arg)
{
        struct ib_user_mad_reg_req ureq;
        struct ib_mad_reg_req req;
        struct ib_mad_agent *agent;
        int agent_id;
        int ret;

        down_write(&file->agent_mutex);

        if (copy_from_user(&ureq, (void __user *) arg, sizeof ureq)) {
                ret = -EFAULT;
                goto out;
        }

        if (ureq.qpn != 0 && ureq.qpn != 1) {
                ret = -EINVAL;
                goto out;
        }

        for (agent_id = 0; agent_id < IB_UMAD_MAX_AGENTS; ++agent_id)
                if (!file->agent[agent_id])
                        goto found;

        ret = -ENOMEM;
        goto out;

found:
        if (ureq.mgmt_class) {
                req.mgmt_class         = ureq.mgmt_class;
                req.mgmt_class_version = ureq.mgmt_class_version;
                memcpy(req.method_mask, ureq.method_mask, sizeof req.method_mask);
                memcpy(req.oui,         ureq.oui,         sizeof req.oui);
        }

        agent = ib_register_mad_agent(file->port->ib_dev, file->port->port_num,
                                      ureq.qpn ? IB_QPT_GSI : IB_QPT_SMI,
                                      ureq.mgmt_class ? &req : NULL,
                                      ureq.rmpp_version,
                                      send_handler, recv_handler, file);
        if (IS_ERR(agent)) {
                ret = PTR_ERR(agent);
                goto out;
        }

        file->agent[agent_id] = agent;

        file->mr[agent_id] = ib_get_dma_mr(agent->qp->pd, IB_ACCESS_LOCAL_WRITE);
        if (IS_ERR(file->mr[agent_id])) {
                ret = -ENOMEM;
                goto err;
        }

        if (put_user(agent_id,
                     (u32 __user *) (arg + offsetof(struct ib_user_mad_reg_req, id)))) {
                ret = -EFAULT;
                goto err_mr;
        }

        ret = 0;
        goto out;

err_mr:
        ib_dereg_mr(file->mr[agent_id]);

err:
        file->agent[agent_id] = NULL;
        ib_unregister_mad_agent(agent);

out:
        up_write(&file->agent_mutex);
        return ret;
}

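/*
 * IB_USER_MAD_UNREGISTER_AGENT: tear down the MR and MAD agent in the
 * given slot and mark the slot free for reuse.
 */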
static int ib_umad_unreg_agent(struct ib_umad_file *file, unsigned long arg)
{
        u32 id;
        int ret = 0;

        down_write(&file->agent_mutex);

        if (get_user(id, (u32 __user *) arg)) {
                ret = -EFAULT;
                goto out;
        }

        /* id is a u32, so only the upper bound needs checking */
        if (id >= IB_UMAD_MAX_AGENTS || !file->agent[id]) {
                ret = -EINVAL;
                goto out;
        }

        ib_dereg_mr(file->mr[id]);
        ib_unregister_mad_agent(file->agent[id]);
        file->agent[id] = NULL;

out:
        up_write(&file->agent_mutex);
        return ret;
}

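/* One handler serves both the native and 32-bit compat ioctl paths. */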
static long ib_umad_ioctl(struct file *filp, unsigned int cmd,
                          unsigned long arg)
{
        switch (cmd) {
        case IB_USER_MAD_REGISTER_AGENT:
                return ib_umad_reg_agent(filp->private_data, arg);
        case IB_USER_MAD_UNREGISTER_AGENT:
                return ib_umad_unreg_agent(filp->private_data, arg);
        default:
                return -ENOIOCTLCMD;
        }
}

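/*
 * Opening umadN allocates the per-file state: the receive queue with
 * its lock and wait queue, and an initially empty agent table.
 */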
static int ib_umad_open(struct inode *inode, struct file *filp)
{
        struct ib_umad_port *port =
                container_of(inode->i_cdev, struct ib_umad_port, dev);
        struct ib_umad_file *file;

        file = kmalloc(sizeof *file, GFP_KERNEL);
        if (!file)
                return -ENOMEM;

        memset(file, 0, sizeof *file);

        spin_lock_init(&file->recv_lock);
        init_rwsem(&file->agent_mutex);
        INIT_LIST_HEAD(&file->recv_list);
        init_waitqueue_head(&file->recv_wait);

        file->port = port;
        filp->private_data = file;

        return 0;
}

static int ib_umad_close(struct inode *inode, struct file *filp)
{
        struct ib_umad_file *file = filp->private_data;
        struct ib_umad_packet *packet, *tmp;
        int i;

        for (i = 0; i < IB_UMAD_MAX_AGENTS; ++i)
                if (file->agent[i]) {
                        ib_dereg_mr(file->mr[i]);
                        ib_unregister_mad_agent(file->agent[i]);
                }

        list_for_each_entry_safe(packet, tmp, &file->recv_list, list)
                kfree(packet);

        kfree(file);

        return 0;
}

static struct file_operations umad_fops = {
        .owner          = THIS_MODULE,
        .read           = ib_umad_read,
        .write          = ib_umad_write,
        .poll           = ib_umad_poll,
        .unlocked_ioctl = ib_umad_ioctl,
        .compat_ioctl   = ib_umad_ioctl,
        .open           = ib_umad_open,
        .release        = ib_umad_close
};

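/*
 * Opening issmN claims the SM role on a port: sm_sem allows only one
 * opener at a time, and the IB_PORT_SM capability bit is set until the
 * file is closed again in ib_umad_sm_close().
 */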
static int ib_umad_sm_open(struct inode *inode, struct file *filp)
{
        struct ib_umad_port *port =
                container_of(inode->i_cdev, struct ib_umad_port, sm_dev);
        struct ib_port_modify props = {
                .set_port_cap_mask = IB_PORT_SM
        };
        int ret;

        if (filp->f_flags & O_NONBLOCK) {
                if (down_trylock(&port->sm_sem))
                        return -EAGAIN;
        } else {
                if (down_interruptible(&port->sm_sem))
                        return -ERESTARTSYS;
        }

        ret = ib_modify_port(port->ib_dev, port->port_num, 0, &props);
        if (ret) {
                up(&port->sm_sem);
                return ret;
        }

        filp->private_data = port;

        return 0;
}

static int ib_umad_sm_close(struct inode *inode, struct file *filp)
{
        struct ib_umad_port *port = filp->private_data;
        struct ib_port_modify props = {
                .clr_port_cap_mask = IB_PORT_SM
        };
        int ret;

        ret = ib_modify_port(port->ib_dev, port->port_num, 0, &props);
        up(&port->sm_sem);

        return ret;
}

static struct file_operations umad_sm_fops = {
        .owner   = THIS_MODULE,
        .open    = ib_umad_sm_open,
        .release = ib_umad_sm_close
};

static struct ib_client umad_client = {
        .name   = "umad",
        .add    = ib_umad_add_one,
        .remove = ib_umad_remove_one
};

static ssize_t show_ibdev(struct class_device *class_dev, char *buf)
{
        struct ib_umad_port *port = class_get_devdata(class_dev);

        return sprintf(buf, "%s\n", port->ib_dev->name);
}
static CLASS_DEVICE_ATTR(ibdev, S_IRUGO, show_ibdev, NULL);

static ssize_t show_port(struct class_device *class_dev, char *buf)
{
        struct ib_umad_port *port = class_get_devdata(class_dev);

        return sprintf(buf, "%d\n", port->port_num);
}
static CLASS_DEVICE_ATTR(port, S_IRUGO, show_port, NULL);

static void ib_umad_release_dev(struct kref *ref)
{
        struct ib_umad_device *dev =
                container_of(ref, struct ib_umad_device, ref);

        kfree(dev);
}

static void ib_umad_release_port(struct class_device *class_dev)
{
        struct ib_umad_port *port = class_get_devdata(class_dev);

        if (class_dev == &port->class_dev) {
                cdev_del(&port->dev);
                clear_bit(port->devnum, dev_map);
        } else {
                cdev_del(&port->sm_dev);
                clear_bit(port->sm_devnum, dev_map);
        }

        kref_put(&port->umad_dev->ref, ib_umad_release_dev);
}

static struct class umad_class = {
        .name    = "infiniband_mad",
        .release = ib_umad_release_port
};

static ssize_t show_abi_version(struct class *class, char *buf)
{
        return sprintf(buf, "%d\n", IB_USER_MAD_ABI_VERSION);
}
static CLASS_ATTR(abi_version, S_IRUGO, show_abi_version, NULL);

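/*
 * Allocate minor numbers for a port's umad and issm devices (issm
 * minors occupy the upper half of the range) and register the two
 * character devices along with their sysfs class devices.
 */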
static int ib_umad_init_port(struct ib_device *device, int port_num,
                             struct ib_umad_port *port)
{
        spin_lock(&map_lock);
        port->devnum = find_first_zero_bit(dev_map, IB_UMAD_MAX_PORTS);
        if (port->devnum >= IB_UMAD_MAX_PORTS) {
                spin_unlock(&map_lock);
                return -1;
        }
        port->sm_devnum = find_next_zero_bit(dev_map, IB_UMAD_MAX_PORTS * 2, IB_UMAD_MAX_PORTS);
        if (port->sm_devnum >= IB_UMAD_MAX_PORTS * 2) {
                spin_unlock(&map_lock);
                return -1;
        }
        set_bit(port->devnum, dev_map);
        set_bit(port->sm_devnum, dev_map);
        spin_unlock(&map_lock);

        port->ib_dev   = device;
        port->port_num = port_num;
        init_MUTEX(&port->sm_sem);

        cdev_init(&port->dev, &umad_fops);
        port->dev.owner = THIS_MODULE;
        kobject_set_name(&port->dev.kobj, "umad%d", port->devnum);
        if (cdev_add(&port->dev, base_dev + port->devnum, 1))
                return -1;

        port->class_dev.class = &umad_class;
        port->class_dev.dev   = device->dma_device;
        port->class_dev.devt  = port->dev.dev;

        snprintf(port->class_dev.class_id, BUS_ID_SIZE, "umad%d", port->devnum);

        if (class_device_register(&port->class_dev))
                goto err_cdev;

        class_set_devdata(&port->class_dev, port);
        kref_get(&port->umad_dev->ref);

        if (class_device_create_file(&port->class_dev, &class_device_attr_ibdev))
                goto err_class;
        if (class_device_create_file(&port->class_dev, &class_device_attr_port))
                goto err_class;

        cdev_init(&port->sm_dev, &umad_sm_fops);
        port->sm_dev.owner = THIS_MODULE;
        /* name the sm cdev's own kobject, not the umad cdev's */
        kobject_set_name(&port->sm_dev.kobj, "issm%d", port->sm_devnum - IB_UMAD_MAX_PORTS);
        if (cdev_add(&port->sm_dev, base_dev + port->sm_devnum, 1))
                goto err_class;

        port->sm_class_dev.class = &umad_class;
        port->sm_class_dev.dev   = device->dma_device;
        port->sm_class_dev.devt  = port->sm_dev.dev;

        snprintf(port->sm_class_dev.class_id, BUS_ID_SIZE, "issm%d", port->sm_devnum - IB_UMAD_MAX_PORTS);

        if (class_device_register(&port->sm_class_dev))
                goto err_sm_cdev;

        class_set_devdata(&port->sm_class_dev, port);
        kref_get(&port->umad_dev->ref);

        if (class_device_create_file(&port->sm_class_dev, &class_device_attr_ibdev))
                goto err_sm_class;
        if (class_device_create_file(&port->sm_class_dev, &class_device_attr_port))
                goto err_sm_class;

        return 0;

err_sm_class:
        class_device_unregister(&port->sm_class_dev);

err_sm_cdev:
        cdev_del(&port->sm_dev);

err_class:
        class_device_unregister(&port->class_dev);

err_cdev:
        cdev_del(&port->dev);
        clear_bit(port->devnum, dev_map);

        return -1;
}

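/*
 * Client callback for a newly registered IB device.  A switch exposes
 * a single port 0; an HCA exposes ports 1..phys_port_cnt.  Set up one
 * ib_umad_port per port and unwind everything on failure.
 */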
static void ib_umad_add_one(struct ib_device *device)
{
        struct ib_umad_device *umad_dev;
        int s, e, i;

        if (device->node_type == IB_NODE_SWITCH)
                s = e = 0;
        else {
                s = 1;
                e = device->phys_port_cnt;
        }

        umad_dev = kmalloc(sizeof *umad_dev +
                           (e - s + 1) * sizeof (struct ib_umad_port),
                           GFP_KERNEL);
        if (!umad_dev)
                return;

        memset(umad_dev, 0, sizeof *umad_dev +
               (e - s + 1) * sizeof (struct ib_umad_port));

        kref_init(&umad_dev->ref);

        umad_dev->start_port = s;
        umad_dev->end_port   = e;

        for (i = s; i <= e; ++i) {
                umad_dev->port[i - s].umad_dev = umad_dev;

                if (ib_umad_init_port(device, i, &umad_dev->port[i - s]))
                        goto err;
        }

        ib_set_client_data(device, &umad_client, umad_dev);

        return;

err:
        while (--i >= s) {
                class_device_unregister(&umad_dev->port[i - s].class_dev);
                class_device_unregister(&umad_dev->port[i - s].sm_class_dev);
        }

        kref_put(&umad_dev->ref, ib_umad_release_dev);
}

static void ib_umad_remove_one(struct ib_device *device)
{
        struct ib_umad_device *umad_dev = ib_get_client_data(device, &umad_client);
        int i;

        if (!umad_dev)
                return;

        for (i = 0; i <= umad_dev->end_port - umad_dev->start_port; ++i) {
                class_device_unregister(&umad_dev->port[i].class_dev);
                class_device_unregister(&umad_dev->port[i].sm_class_dev);
        }

        kref_put(&umad_dev->ref, ib_umad_release_dev);
}

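/*
 * Module init: reserve the static character device region (major 231),
 * register the infiniband_mad class with its abi_version attribute,
 * and register as an IB client so ports are set up and torn down as
 * devices come and go.
 */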
static int __init ib_umad_init(void)
{
        int ret;

        spin_lock_init(&map_lock);

        ret = register_chrdev_region(base_dev, IB_UMAD_MAX_PORTS * 2,
                                     "infiniband_mad");
        if (ret) {
                printk(KERN_ERR "user_mad: couldn't register device number\n");
                goto out;
        }

        ret = class_register(&umad_class);
        if (ret) {
                printk(KERN_ERR "user_mad: couldn't create class infiniband_mad\n");
                goto out_chrdev;
        }

        ret = class_create_file(&umad_class, &class_attr_abi_version);
        if (ret) {
                printk(KERN_ERR "user_mad: couldn't create abi_version attribute\n");
                goto out_class;
        }

        ret = ib_register_client(&umad_client);
        if (ret) {
                printk(KERN_ERR "user_mad: couldn't register ib_umad client\n");
                goto out_class;
        }

        return 0;

out_class:
        class_unregister(&umad_class);

out_chrdev:
        unregister_chrdev_region(base_dev, IB_UMAD_MAX_PORTS * 2);

out:
        return ret;
}

static void __exit ib_umad_cleanup(void)
{
        ib_unregister_client(&umad_client);
        class_unregister(&umad_class);
        unregister_chrdev_region(base_dev, IB_UMAD_MAX_PORTS * 2);
}

module_init(ib_umad_init);
module_exit(ib_umad_cleanup);