diff options
author | Miloslav Trmač <mitr@redhat.com> | 2010-09-13 22:50:20 +0200 |
---|---|---|
committer | Miloslav Trmač <mitr@redhat.com> | 2010-09-13 23:13:14 +0200 |
commit | b6d45c17d91dadaeed30131d8e8389cf33678747 (patch) | |
tree | 1f5c97cb182e2935b4b2ec5200cfb6cd4b6184c8 | |
parent | 244f7c711cdbd3f9ab35050dc7bc623a54d1dfa0 (diff) | |
download | cryptodev-linux-b6d45c17d91dadaeed30131d8e8389cf33678747.tar.gz cryptodev-linux-b6d45c17d91dadaeed30131d8e8389cf33678747.tar.xz cryptodev-linux-b6d45c17d91dadaeed30131d8e8389cf33678747.zip |
Use separate protocols for master and slave sockets
-rw-r--r-- | cryptodev_main.c | 433 | ||||
-rw-r--r-- | tests/ncr.c | 34 |
2 files changed, 329 insertions, 138 deletions
diff --git a/cryptodev_main.c b/cryptodev_main.c index 733e9f1..49e362b 100644 --- a/cryptodev_main.c +++ b/cryptodev_main.c @@ -94,33 +94,95 @@ int __get_userbuf(uint8_t __user * addr, uint32_t len, int write, /* ====== /dev/crypto ====== */ +/* A socket merely used for I/O. */ +struct data_sock { + struct sock sk; + struct alg_sock *master; /* See below for reference counting */ + unsigned index; +}; + +/* A socket holding a crypto_tfm. */ struct alg_sock { struct sock sk; struct sockaddr_alg addr; - struct sock *queued; + struct data_sock *slaves[1]; + unsigned num_slaves, accept_idx; struct hash_data hash; }; +/* Socket reference counting: + The primary reference to alg_sock is from user-space. + + After listen(), data_sock instances are referenced through alg_sock::slaves. + Unaccepted data_sock instancess do not not hold a reference back to alg_sock. + + accept() adds an alg_sock reference through data_sock::master. + + This creates a reference loop alg_sock<->data_sock. The loop is broken on + last close() of any side of the pair, i.e. proto_ops::release, when + the reference is dropped. */ + +static struct data_sock *data_sk(struct sock *sk) +{ + return container_of(sk, struct data_sock, sk); +} + static struct alg_sock *alg_sk(struct sock *sk) { return container_of(sk, struct alg_sock, sk); } -static struct proto alg_proto = { -/* void (*close)(struct sock *sk, */ -/* long timeout); */ -/* int (*connect)(struct sock *sk, */ -/* struct sockaddr *uaddr, */ -/* int addr_len); */ -/* int (*disconnect)(struct sock *sk, int flags); */ +static struct proto data_proto = { +/* int (*init)(struct sock *sk); */ +/* void (*destroy)(struct sock *sk); */ +/* int (*setsockopt)(struct sock *sk, int level, */ +/* int optname, char __user *optval, */ +/* unsigned int optlen); */ +/* int (*getsockopt)(struct sock *sk, int level, */ +/* int optname, char __user *optval, */ +/* int __user *option); */ +/* #ifdef CONFIG_COMPAT */ +/* int (*compat_setsockopt)(struct sock *sk, */ +/* int level, */ +/* int optname, char __user *optval, */ +/* unsigned int optlen); */ +/* int (*compat_getsockopt)(struct sock *sk, */ +/* int level, */ +/* int optname, char __user *optval, */ +/* int __user *option); */ +/* #endif */ +/* int (*recvmsg)(struct kiocb *iocb, struct sock *sk, */ +/* struct msghdr *msg, */ +/* size_t len, int noblock, int flags, */ +/* int *addr_len); */ -/* struct sock * (*accept) (struct sock *sk, int flags, int *err); */ +/* /\* Memory pressure *\/ */ +/* atomic_t *memory_allocated; /\* Current allocated memory. *\/ */ +/* struct percpu_counter *sockets_allocated; /\* Current number of sockets. *\/ */ +/* /\* */ +/* * Pressure flag: try to collapse. */ +/* * Technical note: it is used by multiple contexts non atomically. */ +/* * All the __sk_mem_schedule() is of this nature: accounting */ +/* * is strict, actions are advisory and have some latency. */ +/* *\/ */ +/* int *memory_pressure; */ +/* int *sysctl_mem; */ +/* int *sysctl_wmem; */ +/* int *sysctl_rmem; */ +/* int max_header; */ -/* int (*ioctl)(struct sock *sk, int cmd, */ -/* unsigned long arg); */ + .obj_size = sizeof(struct data_sock), +/* int slab_flags; */ + +/* struct percpu_counter *orphan_count; */ + + .owner = THIS_MODULE, + .name = "ALG-data", +}; + +static struct proto alg_proto = { /* int (*init)(struct sock *sk); */ /* void (*destroy)(struct sock *sk); */ -/* void (*shutdown)(struct sock *sk, int how); */ /* int (*setsockopt)(struct sock *sk, int level, */ /* int optname, char __user *optval, */ /* unsigned int optlen); */ @@ -137,32 +199,12 @@ static struct proto alg_proto = { /* int optname, char __user *optval, */ /* int __user *option); */ /* #endif */ -/* int (*sendmsg)(struct kiocb *iocb, struct sock *sk, */ -/* struct msghdr *msg, size_t len); */ /* int (*recvmsg)(struct kiocb *iocb, struct sock *sk, */ /* struct msghdr *msg, */ /* size_t len, int noblock, int flags, */ /* int *addr_len); */ -/* int (*sendpage)(struct sock *sk, struct page *page, */ -/* int offset, size_t size, int flags); */ -/* int (*bind)(struct sock *sk, */ -/* struct sockaddr *uaddr, int addr_len); */ - -/* int (*backlog_rcv) (struct sock *sk, */ -/* struct sk_buff *skb); */ - -/* /\* Keeping track of sk's, looking them up, and port selection methods. *\/ */ -/* void (*hash)(struct sock *sk); */ -/* void (*unhash)(struct sock *sk); */ -/* int (*get_port)(struct sock *sk, unsigned short snum); */ - -/* /\* Keeping track of sockets in use *\/ */ -/* #ifdef CONFIG_PROC_FS */ -/* unsigned int inuse_idx; */ -/* #endif */ /* /\* Memory pressure *\/ */ -/* void (*enter_memory_pressure)(struct sock *sk); */ /* atomic_t *memory_allocated; /\* Current allocated memory. *\/ */ /* struct percpu_counter *sockets_allocated; /\* Current number of sockets. *\/ */ /* /\* */ @@ -182,53 +224,184 @@ static struct proto alg_proto = { /* struct percpu_counter *orphan_count; */ -/* struct request_sock_ops *rsk_prot; */ -/* struct timewait_sock_ops *twsk_prot; */ - -/* union { */ -/* struct inet_hashinfo *hashinfo; */ -/* struct udp_table *udp_table; */ -/* struct raw_hashinfo *raw_hash; */ -/* } h; */ - .owner = THIS_MODULE, .name = "ALG", +}; -/* struct list_head node; */ -/* #ifdef SOCK_REFCNT_DEBUG */ -/* atomic_t socks; */ -/* #endif */ +static int data_release(struct socket *sock) +{ + struct sock *sk; + struct data_sock *dsk; + + sk = sock->sk; + if (unlikely(sk == NULL)) + return 0; + + dsk = data_sk(sk); + + local_bh_disable(); + sock_prot_inuse_add(sock_net(sk), &data_proto, -1); + local_bh_enable(); + + BUG_ON(dsk->master == NULL); + sock_put(&dsk->master->sk); + + sock_put(sk); +} + +static int do_data_sendmsg(struct kiocb *iocb, struct alg_sock *ask, + unsigned index, struct msghdr *m, size_t total_len) +{ + char *buf; + int res; + + // FIXME: locking + + // FIXME: make generic + BUG_ON(index != 0); + BUG_ON(ask->hash.init == 0); + + // FIXME: limit size, or use socket buffer + buf = kmalloc(total_len, GFP_KERNEL); + if (!buf) + return -ENOMEM; + + res = memcpy_fromiovec(buf, m->msg_iov, total_len); + if (res != 0) + goto err; + + // FIXME + res = _cryptodev_hash_update(&ask->hash, buf, total_len); + if (res < 0) + goto err; + + res = total_len; + +err: + kfree(buf); + return res; +} + +static int data_sendmsg(struct kiocb *iocb, struct socket *sock, + struct msghdr *m, size_t total_len) +{ + struct data_sock *dsk; + + dsk = data_sk(sock->sk); + BUG_ON(dsk->master == NULL); + return do_data_sendmsg(iocb, dsk->master, dsk->index, m, total_len); +} + +static int do_data_recvmsg(struct kiocb *iocb, struct alg_sock *ask, + unsigned index, struct msghdr *m, size_t total_len, + int flags) +{ + char digest[NCR_HASH_MAX_OUTPUT_SIZE]; + int res; + + // FIXME: locking + + // FIXME: make generic + BUG_ON(index != 0); + BUG_ON(ask->hash.init == 0); + + if (total_len < ask->hash.digestsize) + return -EINVAL; + + BUG_ON(ask->hash.digestsize > sizeof(digest)); + res = cryptodev_hash_final(&ask->hash, digest); + if (res < 0) + return res; + + res = memcpy_toiovec(m->msg_iov, digest, ask->hash.digestsize); + if (res != 0) + return res; + + res = cryptodev_hash_reset(&ask->hash); + if (res != 0) + return res; + + return ask->hash.digestsize; +} + +static int data_recvmsg(struct kiocb *iocb, struct socket *sock, + struct msghdr *m, size_t total_len, int flags) +{ + struct data_sock *dsk; + + dsk = data_sk(sock->sk); + BUG_ON(dsk->master == NULL); + return do_data_recvmsg(iocb, dsk->master, dsk->index, m, total_len, + flags); +} + +static const struct proto_ops data_proto_ops = { + .family = PF_ALG, + .owner = THIS_MODULE, + .release = data_release, + .bind = sock_no_bind, + .connect = sock_no_connect, + .socketpair = sock_no_socketpair, + .accept = sock_no_accept, + .getname = sock_no_getname, + .poll = sock_no_poll, + .ioctl = sock_no_ioctl, + .compat_ioctl = NULL, + .listen = sock_no_listen, + .shutdown = sock_no_shutdown, + .setsockopt = sock_no_setsockopt, + .getsockopt = sock_no_getsockopt, + .compat_setsockopt = NULL, + .compat_getsockopt = NULL, + .sendmsg = data_sendmsg, + .recvmsg = data_recvmsg, + .mmap = sock_no_mmap, + .sendpage = NULL, + .splice_read = NULL, }; static int alg_release(struct socket *sock) { struct sock *sk; struct alg_sock *ask; + size_t i; sk = sock->sk; - if (sk == NULL) + if (unlikely(sk == NULL)) return 0; sock->sk = NULL; + ask = alg_sk(sk); + // skb_queue_purge(&sk->sk_write_queue);??? - ask = alg_sk(sk); - if (ask->queued != NULL) { - local_bh_disable(); - sock_prot_inuse_add(sock_net(ask->queued), &alg_proto, -1); - local_bh_enable(); - sock_put(ask->queued); + for (i = 0; i < ask->num_slaves; i++) { + struct data_sock *dsk; + + dsk = ask->slaves[i]; + ask->slaves[i] = NULL; + + BUG_ON(dsk == NULL); + if (i >= ask->accept_idx) { // FIXME: cleaner - on last sock_put + local_bh_disable(); + sock_prot_inuse_add(sock_net(&dsk->sk), &data_proto, + -1); + local_bh_enable(); + } + sock_put(&dsk->sk); } if (ask->hash.init != 0) cryptodev_hash_deinit(&ask->hash); + // FIXME: on last sock_put local_bh_disable(); sock_prot_inuse_add(sock_net(sk), &alg_proto, -1); local_bh_enable(); sock_put(sk); + return 0; } @@ -237,18 +410,30 @@ static int alg_bind(struct socket *sock, struct sockaddr *myaddr, { struct sockaddr_alg *addr; struct alg_sock *ask; + int res; if (myaddr->sa_family != AF_ALG || sockaddr_len < sizeof(*addr)) return -EINVAL; addr = (struct sockaddr_alg *)myaddr; + if (memchr(addr->salg_type, '\0', sizeof(addr->salg_type)) == NULL) + return -EINVAL; + + ask = alg_sk(sock->sk); // FIXME: locking + if (ask->addr.salg_type[0] != 0) + return -EINVAL; // FIXME: better error code for "already bound"? + // FIXME if (strncmp(addr->salg_type, "hash", sizeof(addr->salg_type)) != 0) return -EINVAL; - ask = alg_sk(sock->sk); + BUG_ON(ask->hash.init != 0); + res = cryptodev_hash_init(&ask->hash, addr->salg_tfm, NULL, 0); + if (res != 0) + return res; + ask->addr = *addr; return 0; } @@ -256,121 +441,112 @@ static int alg_bind(struct socket *sock, struct sockaddr *myaddr, static int alg_accept(struct socket *sock, struct socket *newsock, int flags) { struct alg_sock *ask; - struct sock *newsk; + struct data_sock *dsk; // FIXME: locking ask = alg_sk(sock->sk); - if (ask->queued == NULL) + if (ask->accept_idx >= ask->num_slaves) return -EINVAL; - newsk = ask->queued; - ask->queued = NULL; - sock_graft(newsk, newsock); + dsk = ask->slaves[ask->accept_idx]; + ask->accept_idx++; + + sock_hold(&ask->sk); + dsk->master = ask; + + sock_hold(&dsk->sk); + sock_graft(&dsk->sk, newsock); return 0; } static int alg_listen(struct socket *sock, int len) { struct net *net; - struct sock *newsk; - struct alg_sock *ask, *newask; - int res; + struct alg_sock *ask; + struct data_sock *dsk; + size_t i; // FIXME: locking net = sock_net(sock->sk); ask = alg_sk(sock->sk); - if (ask->addr.salg_type[0] == 0) - return -EINVAL; // FIXME: better error code for "not bound"? + + if (ask->num_slaves != 0) + return -EINVAL; + // FIXME: type-specific if (len != 1) return -EINVAL; - if (ask->queued != NULL) - return -EINVAL; + BUG_ON(len > ARRAY_SIZE(ask->slaves)); - newsk = sk_alloc(net, PF_ALG, GFP_KERNEL, &alg_proto); - if (newsk == NULL) - return -ENOMEM; - newask = alg_sk(newsk); - // FIXME - res = cryptodev_hash_init(&newask->hash, ask->addr.salg_tfm, NULL, 0); - if (res != 0) { - sock_put(newsk); - return res; - } + for (i = 0; i < len; i++) { + struct sock *newsk; - sock_init_data(NULL, newsk); - local_bh_disable(); - sock_prot_inuse_add(net, &alg_proto, 1); - local_bh_enable(); + newsk = sk_alloc(net, PF_ALG, GFP_KERNEL, &data_proto); + if (newsk == NULL) + goto err_partial; + sock_init_data(NULL, newsk); - ask->queued = newsk; + local_bh_disable(); + sock_prot_inuse_add(net, &data_proto, 1); + local_bh_enable(); + + dsk = data_sk(newsk); + dsk->index = i; + + ask->slaves[i] = dsk; + } + + ask->num_slaves = len; return 0; + +err_partial: + while (i != 0) { + i--; + + dsk = ask->slaves[i]; + ask->slaves[i] = NULL; + + BUG_ON(dsk != NULL); + + // FIXME: cleaner - on last sock_put + local_bh_disable(); + sock_prot_inuse_add(sock_net(&dsk->sk), &data_proto, -1); + local_bh_enable(); + + sock_put(&dsk->sk); + } + return -ENOMEM; } static int alg_sendmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *m, size_t total_len) { struct alg_sock *ask; - char *buf; - int res; - // FIXME: locking ask = alg_sk(sock->sk); - if (ask->hash.init == 0) - return -EINVAL; - - // FIXME: limit size, or use socket buffer - buf = kmalloc(total_len, GFP_KERNEL); - if (!buf) - return -ENOMEM; - - res = memcpy_fromiovec(buf, m->msg_iov, total_len); - if (res != 0) - goto err; - - // FIXME - res = _cryptodev_hash_update(&ask->hash, buf, total_len); - if (res < 0) - goto err; - res = total_len; + // FIXME: locking + if (ask->addr.salg_type[0] == 0) + return -EINVAL; // FIXME: better error code for "not bound"? -err: - kfree(buf); - return res; + return do_data_sendmsg(iocb, ask, 0, m, total_len); } static int alg_recvmsg(struct kiocb *iocb, struct socket *sock, struct msghdr *m, size_t total_len, int flags) { - char digest[NCR_HASH_MAX_OUTPUT_SIZE]; struct alg_sock *ask; - int res; // FIXME: locking ask = alg_sk(sock->sk); - if (ask->hash.init == 0) - return -EINVAL; - if (total_len < ask->hash.digestsize) - return -EINVAL; - // FIXME - BUG_ON(ask->hash.digestsize > sizeof(digest)); - res = cryptodev_hash_final(&ask->hash, digest); - if (res < 0) - return res; - - res = memcpy_toiovec(m->msg_iov, digest, ask->hash.digestsize); - if (res != 0) - return res; - - res = cryptodev_hash_reset(&ask->hash); - if (res != 0) - return res; + // FIXME: locking + if (ask->addr.salg_type[0] == 0) + return -EINVAL; // FIXME: better error code for "not bound"? - return ask->hash.digestsize; + return do_data_recvmsg(iocb, ask, 0, m, total_len, flags); } static const struct proto_ops alg_proto_ops = { @@ -439,15 +615,21 @@ static int __init init_cryptodev(void) if (unlikely(rc != 0)) goto err; + rc = proto_register(&data_proto, 1); + if (unlikely(rc != 0)) + goto err_alg_proto; + rc = sock_register(&alg_pf); if (unlikely(rc != 0)) - goto err_proto; + goto err_data_proto; printk(KERN_INFO PFX "driver loaded.\n"); return 0; -err_proto: +err_data_proto: + proto_unregister(&data_proto); +err_alg_proto: proto_unregister(&alg_proto); err: printk(KERN_ERR PFX "driver registration failed\n"); @@ -457,6 +639,7 @@ err: static void __exit exit_cryptodev(void) { sock_unregister(PF_ALG); + proto_unregister(&data_proto); proto_unregister(&alg_proto); printk(KERN_INFO PFX "driver unloaded.\n"); diff --git a/tests/ncr.c b/tests/ncr.c index 207a233..362cdaa 100644 --- a/tests/ncr.c +++ b/tests/ncr.c @@ -96,14 +96,17 @@ struct hash_vectors_st { #define HASH_DATA_SIZE 64 /* SHA1 and other hashes */ -static int test_ncr_hash() +static int test_ncr_hash(int with_accept) { uint8_t data[HASH_DATA_SIZE]; int i, j; ssize_t data_size; /* convert it to key */ - fprintf(stdout, "Tests on Hashes\n"); + if (with_accept) + fprintf(stdout, "Tests on Hashes with accept()\n"); + else + fprintf(stdout, "Tests on Hashes without accept()\n"); for (i = 0; i < sizeof(hash_vectors) / sizeof(hash_vectors[0]); i++) { struct sockaddr_alg salg; int fd, hfd; @@ -124,18 +127,21 @@ static int test_ncr_hash() return 1; } - if (listen(fd, 1) != 0) { - perror("listen()"); - return 1; - } + if (with_accept) { + if (listen(fd, 1) != 0) { + perror("listen()"); + return 1; + } - hfd = accept(fd, NULL, NULL); - if (hfd < 0) { - perror("accept()"); - return 1; - } + hfd = accept(fd, NULL, NULL); + if (hfd < 0) { + perror("accept()"); + return 1; + } - close(fd); + close(fd); + } else + hfd = fd; errno = 0; if (write(hfd, hash_vectors[i].plaintext, @@ -182,7 +188,9 @@ static int test_ncr_hash() int main() { - if (test_ncr_hash()) + if (test_ncr_hash(0)) + return 1; + if (test_ncr_hash(1)) return 1; return 0; |