summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorMiloslav Trmač <mitr@redhat.com>2010-09-13 22:50:20 +0200
committerMiloslav Trmač <mitr@redhat.com>2010-09-13 23:13:14 +0200
commitb6d45c17d91dadaeed30131d8e8389cf33678747 (patch)
tree1f5c97cb182e2935b4b2ec5200cfb6cd4b6184c8
parent244f7c711cdbd3f9ab35050dc7bc623a54d1dfa0 (diff)
downloadcryptodev-linux-b6d45c17d91dadaeed30131d8e8389cf33678747.tar.gz
cryptodev-linux-b6d45c17d91dadaeed30131d8e8389cf33678747.tar.xz
cryptodev-linux-b6d45c17d91dadaeed30131d8e8389cf33678747.zip
Use separate protocols for master and slave sockets
-rw-r--r--cryptodev_main.c433
-rw-r--r--tests/ncr.c34
2 files changed, 329 insertions, 138 deletions
diff --git a/cryptodev_main.c b/cryptodev_main.c
index 733e9f1..49e362b 100644
--- a/cryptodev_main.c
+++ b/cryptodev_main.c
@@ -94,33 +94,95 @@ int __get_userbuf(uint8_t __user * addr, uint32_t len, int write,
/* ====== /dev/crypto ====== */
+/* A socket merely used for I/O. */
+struct data_sock {
+ struct sock sk;
+ struct alg_sock *master; /* See below for reference counting */
+ unsigned index;
+};
+
+/* A socket holding a crypto_tfm. */
struct alg_sock {
struct sock sk;
struct sockaddr_alg addr;
- struct sock *queued;
+ struct data_sock *slaves[1];
+ unsigned num_slaves, accept_idx;
struct hash_data hash;
};
+/* Socket reference counting:
+ The primary reference to alg_sock is from user-space.
+
+ After listen(), data_sock instances are referenced through alg_sock::slaves.
+ Unaccepted data_sock instancess do not not hold a reference back to alg_sock.
+
+ accept() adds an alg_sock reference through data_sock::master.
+
+ This creates a reference loop alg_sock<->data_sock. The loop is broken on
+ last close() of any side of the pair, i.e. proto_ops::release, when
+ the reference is dropped. */
+
+static struct data_sock *data_sk(struct sock *sk)
+{
+ return container_of(sk, struct data_sock, sk);
+}
+
static struct alg_sock *alg_sk(struct sock *sk)
{
return container_of(sk, struct alg_sock, sk);
}
-static struct proto alg_proto = {
-/* void (*close)(struct sock *sk, */
-/* long timeout); */
-/* int (*connect)(struct sock *sk, */
-/* struct sockaddr *uaddr, */
-/* int addr_len); */
-/* int (*disconnect)(struct sock *sk, int flags); */
+static struct proto data_proto = {
+/* int (*init)(struct sock *sk); */
+/* void (*destroy)(struct sock *sk); */
+/* int (*setsockopt)(struct sock *sk, int level, */
+/* int optname, char __user *optval, */
+/* unsigned int optlen); */
+/* int (*getsockopt)(struct sock *sk, int level, */
+/* int optname, char __user *optval, */
+/* int __user *option); */
+/* #ifdef CONFIG_COMPAT */
+/* int (*compat_setsockopt)(struct sock *sk, */
+/* int level, */
+/* int optname, char __user *optval, */
+/* unsigned int optlen); */
+/* int (*compat_getsockopt)(struct sock *sk, */
+/* int level, */
+/* int optname, char __user *optval, */
+/* int __user *option); */
+/* #endif */
+/* int (*recvmsg)(struct kiocb *iocb, struct sock *sk, */
+/* struct msghdr *msg, */
+/* size_t len, int noblock, int flags, */
+/* int *addr_len); */
-/* struct sock * (*accept) (struct sock *sk, int flags, int *err); */
+/* /\* Memory pressure *\/ */
+/* atomic_t *memory_allocated; /\* Current allocated memory. *\/ */
+/* struct percpu_counter *sockets_allocated; /\* Current number of sockets. *\/ */
+/* /\* */
+/* * Pressure flag: try to collapse. */
+/* * Technical note: it is used by multiple contexts non atomically. */
+/* * All the __sk_mem_schedule() is of this nature: accounting */
+/* * is strict, actions are advisory and have some latency. */
+/* *\/ */
+/* int *memory_pressure; */
+/* int *sysctl_mem; */
+/* int *sysctl_wmem; */
+/* int *sysctl_rmem; */
+/* int max_header; */
-/* int (*ioctl)(struct sock *sk, int cmd, */
-/* unsigned long arg); */
+ .obj_size = sizeof(struct data_sock),
+/* int slab_flags; */
+
+/* struct percpu_counter *orphan_count; */
+
+ .owner = THIS_MODULE,
+ .name = "ALG-data",
+};
+
+static struct proto alg_proto = {
/* int (*init)(struct sock *sk); */
/* void (*destroy)(struct sock *sk); */
-/* void (*shutdown)(struct sock *sk, int how); */
/* int (*setsockopt)(struct sock *sk, int level, */
/* int optname, char __user *optval, */
/* unsigned int optlen); */
@@ -137,32 +199,12 @@ static struct proto alg_proto = {
/* int optname, char __user *optval, */
/* int __user *option); */
/* #endif */
-/* int (*sendmsg)(struct kiocb *iocb, struct sock *sk, */
-/* struct msghdr *msg, size_t len); */
/* int (*recvmsg)(struct kiocb *iocb, struct sock *sk, */
/* struct msghdr *msg, */
/* size_t len, int noblock, int flags, */
/* int *addr_len); */
-/* int (*sendpage)(struct sock *sk, struct page *page, */
-/* int offset, size_t size, int flags); */
-/* int (*bind)(struct sock *sk, */
-/* struct sockaddr *uaddr, int addr_len); */
-
-/* int (*backlog_rcv) (struct sock *sk, */
-/* struct sk_buff *skb); */
-
-/* /\* Keeping track of sk's, looking them up, and port selection methods. *\/ */
-/* void (*hash)(struct sock *sk); */
-/* void (*unhash)(struct sock *sk); */
-/* int (*get_port)(struct sock *sk, unsigned short snum); */
-
-/* /\* Keeping track of sockets in use *\/ */
-/* #ifdef CONFIG_PROC_FS */
-/* unsigned int inuse_idx; */
-/* #endif */
/* /\* Memory pressure *\/ */
-/* void (*enter_memory_pressure)(struct sock *sk); */
/* atomic_t *memory_allocated; /\* Current allocated memory. *\/ */
/* struct percpu_counter *sockets_allocated; /\* Current number of sockets. *\/ */
/* /\* */
@@ -182,53 +224,184 @@ static struct proto alg_proto = {
/* struct percpu_counter *orphan_count; */
-/* struct request_sock_ops *rsk_prot; */
-/* struct timewait_sock_ops *twsk_prot; */
-
-/* union { */
-/* struct inet_hashinfo *hashinfo; */
-/* struct udp_table *udp_table; */
-/* struct raw_hashinfo *raw_hash; */
-/* } h; */
-
.owner = THIS_MODULE,
.name = "ALG",
+};
-/* struct list_head node; */
-/* #ifdef SOCK_REFCNT_DEBUG */
-/* atomic_t socks; */
-/* #endif */
+static int data_release(struct socket *sock)
+{
+ struct sock *sk;
+ struct data_sock *dsk;
+
+ sk = sock->sk;
+ if (unlikely(sk == NULL))
+ return 0;
+
+ dsk = data_sk(sk);
+
+ local_bh_disable();
+ sock_prot_inuse_add(sock_net(sk), &data_proto, -1);
+ local_bh_enable();
+
+ BUG_ON(dsk->master == NULL);
+ sock_put(&dsk->master->sk);
+
+ sock_put(sk);
+}
+
+static int do_data_sendmsg(struct kiocb *iocb, struct alg_sock *ask,
+ unsigned index, struct msghdr *m, size_t total_len)
+{
+ char *buf;
+ int res;
+
+ // FIXME: locking
+
+ // FIXME: make generic
+ BUG_ON(index != 0);
+ BUG_ON(ask->hash.init == 0);
+
+ // FIXME: limit size, or use socket buffer
+ buf = kmalloc(total_len, GFP_KERNEL);
+ if (!buf)
+ return -ENOMEM;
+
+ res = memcpy_fromiovec(buf, m->msg_iov, total_len);
+ if (res != 0)
+ goto err;
+
+ // FIXME
+ res = _cryptodev_hash_update(&ask->hash, buf, total_len);
+ if (res < 0)
+ goto err;
+
+ res = total_len;
+
+err:
+ kfree(buf);
+ return res;
+}
+
+static int data_sendmsg(struct kiocb *iocb, struct socket *sock,
+ struct msghdr *m, size_t total_len)
+{
+ struct data_sock *dsk;
+
+ dsk = data_sk(sock->sk);
+ BUG_ON(dsk->master == NULL);
+ return do_data_sendmsg(iocb, dsk->master, dsk->index, m, total_len);
+}
+
+static int do_data_recvmsg(struct kiocb *iocb, struct alg_sock *ask,
+ unsigned index, struct msghdr *m, size_t total_len,
+ int flags)
+{
+ char digest[NCR_HASH_MAX_OUTPUT_SIZE];
+ int res;
+
+ // FIXME: locking
+
+ // FIXME: make generic
+ BUG_ON(index != 0);
+ BUG_ON(ask->hash.init == 0);
+
+ if (total_len < ask->hash.digestsize)
+ return -EINVAL;
+
+ BUG_ON(ask->hash.digestsize > sizeof(digest));
+ res = cryptodev_hash_final(&ask->hash, digest);
+ if (res < 0)
+ return res;
+
+ res = memcpy_toiovec(m->msg_iov, digest, ask->hash.digestsize);
+ if (res != 0)
+ return res;
+
+ res = cryptodev_hash_reset(&ask->hash);
+ if (res != 0)
+ return res;
+
+ return ask->hash.digestsize;
+}
+
+static int data_recvmsg(struct kiocb *iocb, struct socket *sock,
+ struct msghdr *m, size_t total_len, int flags)
+{
+ struct data_sock *dsk;
+
+ dsk = data_sk(sock->sk);
+ BUG_ON(dsk->master == NULL);
+ return do_data_recvmsg(iocb, dsk->master, dsk->index, m, total_len,
+ flags);
+}
+
+static const struct proto_ops data_proto_ops = {
+ .family = PF_ALG,
+ .owner = THIS_MODULE,
+ .release = data_release,
+ .bind = sock_no_bind,
+ .connect = sock_no_connect,
+ .socketpair = sock_no_socketpair,
+ .accept = sock_no_accept,
+ .getname = sock_no_getname,
+ .poll = sock_no_poll,
+ .ioctl = sock_no_ioctl,
+ .compat_ioctl = NULL,
+ .listen = sock_no_listen,
+ .shutdown = sock_no_shutdown,
+ .setsockopt = sock_no_setsockopt,
+ .getsockopt = sock_no_getsockopt,
+ .compat_setsockopt = NULL,
+ .compat_getsockopt = NULL,
+ .sendmsg = data_sendmsg,
+ .recvmsg = data_recvmsg,
+ .mmap = sock_no_mmap,
+ .sendpage = NULL,
+ .splice_read = NULL,
};
static int alg_release(struct socket *sock)
{
struct sock *sk;
struct alg_sock *ask;
+ size_t i;
sk = sock->sk;
- if (sk == NULL)
+ if (unlikely(sk == NULL))
return 0;
sock->sk = NULL;
+ ask = alg_sk(sk);
+
// skb_queue_purge(&sk->sk_write_queue);???
- ask = alg_sk(sk);
- if (ask->queued != NULL) {
- local_bh_disable();
- sock_prot_inuse_add(sock_net(ask->queued), &alg_proto, -1);
- local_bh_enable();
- sock_put(ask->queued);
+ for (i = 0; i < ask->num_slaves; i++) {
+ struct data_sock *dsk;
+
+ dsk = ask->slaves[i];
+ ask->slaves[i] = NULL;
+
+ BUG_ON(dsk == NULL);
+ if (i >= ask->accept_idx) { // FIXME: cleaner - on last sock_put
+ local_bh_disable();
+ sock_prot_inuse_add(sock_net(&dsk->sk), &data_proto,
+ -1);
+ local_bh_enable();
+ }
+ sock_put(&dsk->sk);
}
if (ask->hash.init != 0)
cryptodev_hash_deinit(&ask->hash);
+ // FIXME: on last sock_put
local_bh_disable();
sock_prot_inuse_add(sock_net(sk), &alg_proto, -1);
local_bh_enable();
sock_put(sk);
+
return 0;
}
@@ -237,18 +410,30 @@ static int alg_bind(struct socket *sock, struct sockaddr *myaddr,
{
struct sockaddr_alg *addr;
struct alg_sock *ask;
+ int res;
if (myaddr->sa_family != AF_ALG || sockaddr_len < sizeof(*addr))
return -EINVAL;
addr = (struct sockaddr_alg *)myaddr;
+ if (memchr(addr->salg_type, '\0', sizeof(addr->salg_type)) == NULL)
+ return -EINVAL;
+
+ ask = alg_sk(sock->sk);
// FIXME: locking
+ if (ask->addr.salg_type[0] != 0)
+ return -EINVAL; // FIXME: better error code for "already bound"?
+
// FIXME
if (strncmp(addr->salg_type, "hash", sizeof(addr->salg_type)) != 0)
return -EINVAL;
- ask = alg_sk(sock->sk);
+ BUG_ON(ask->hash.init != 0);
+ res = cryptodev_hash_init(&ask->hash, addr->salg_tfm, NULL, 0);
+ if (res != 0)
+ return res;
+
ask->addr = *addr;
return 0;
}
@@ -256,121 +441,112 @@ static int alg_bind(struct socket *sock, struct sockaddr *myaddr,
static int alg_accept(struct socket *sock, struct socket *newsock, int flags)
{
struct alg_sock *ask;
- struct sock *newsk;
+ struct data_sock *dsk;
// FIXME: locking
ask = alg_sk(sock->sk);
- if (ask->queued == NULL)
+ if (ask->accept_idx >= ask->num_slaves)
return -EINVAL;
- newsk = ask->queued;
- ask->queued = NULL;
- sock_graft(newsk, newsock);
+ dsk = ask->slaves[ask->accept_idx];
+ ask->accept_idx++;
+
+ sock_hold(&ask->sk);
+ dsk->master = ask;
+
+ sock_hold(&dsk->sk);
+ sock_graft(&dsk->sk, newsock);
return 0;
}
static int alg_listen(struct socket *sock, int len)
{
struct net *net;
- struct sock *newsk;
- struct alg_sock *ask, *newask;
- int res;
+ struct alg_sock *ask;
+ struct data_sock *dsk;
+ size_t i;
// FIXME: locking
net = sock_net(sock->sk);
ask = alg_sk(sock->sk);
- if (ask->addr.salg_type[0] == 0)
- return -EINVAL; // FIXME: better error code for "not bound"?
+
+ if (ask->num_slaves != 0)
+ return -EINVAL;
+
// FIXME: type-specific
if (len != 1)
return -EINVAL;
- if (ask->queued != NULL)
- return -EINVAL;
+ BUG_ON(len > ARRAY_SIZE(ask->slaves));
- newsk = sk_alloc(net, PF_ALG, GFP_KERNEL, &alg_proto);
- if (newsk == NULL)
- return -ENOMEM;
- newask = alg_sk(newsk);
- // FIXME
- res = cryptodev_hash_init(&newask->hash, ask->addr.salg_tfm, NULL, 0);
- if (res != 0) {
- sock_put(newsk);
- return res;
- }
+ for (i = 0; i < len; i++) {
+ struct sock *newsk;
- sock_init_data(NULL, newsk);
- local_bh_disable();
- sock_prot_inuse_add(net, &alg_proto, 1);
- local_bh_enable();
+ newsk = sk_alloc(net, PF_ALG, GFP_KERNEL, &data_proto);
+ if (newsk == NULL)
+ goto err_partial;
+ sock_init_data(NULL, newsk);
- ask->queued = newsk;
+ local_bh_disable();
+ sock_prot_inuse_add(net, &data_proto, 1);
+ local_bh_enable();
+
+ dsk = data_sk(newsk);
+ dsk->index = i;
+
+ ask->slaves[i] = dsk;
+ }
+
+ ask->num_slaves = len;
return 0;
+
+err_partial:
+ while (i != 0) {
+ i--;
+
+ dsk = ask->slaves[i];
+ ask->slaves[i] = NULL;
+
+ BUG_ON(dsk != NULL);
+
+ // FIXME: cleaner - on last sock_put
+ local_bh_disable();
+ sock_prot_inuse_add(sock_net(&dsk->sk), &data_proto, -1);
+ local_bh_enable();
+
+ sock_put(&dsk->sk);
+ }
+ return -ENOMEM;
}
static int alg_sendmsg(struct kiocb *iocb, struct socket *sock,
struct msghdr *m, size_t total_len)
{
struct alg_sock *ask;
- char *buf;
- int res;
- // FIXME: locking
ask = alg_sk(sock->sk);
- if (ask->hash.init == 0)
- return -EINVAL;
-
- // FIXME: limit size, or use socket buffer
- buf = kmalloc(total_len, GFP_KERNEL);
- if (!buf)
- return -ENOMEM;
-
- res = memcpy_fromiovec(buf, m->msg_iov, total_len);
- if (res != 0)
- goto err;
-
- // FIXME
- res = _cryptodev_hash_update(&ask->hash, buf, total_len);
- if (res < 0)
- goto err;
- res = total_len;
+ // FIXME: locking
+ if (ask->addr.salg_type[0] == 0)
+ return -EINVAL; // FIXME: better error code for "not bound"?
-err:
- kfree(buf);
- return res;
+ return do_data_sendmsg(iocb, ask, 0, m, total_len);
}
static int alg_recvmsg(struct kiocb *iocb, struct socket *sock,
struct msghdr *m, size_t total_len, int flags)
{
- char digest[NCR_HASH_MAX_OUTPUT_SIZE];
struct alg_sock *ask;
- int res;
// FIXME: locking
ask = alg_sk(sock->sk);
- if (ask->hash.init == 0)
- return -EINVAL;
- if (total_len < ask->hash.digestsize)
- return -EINVAL;
- // FIXME
- BUG_ON(ask->hash.digestsize > sizeof(digest));
- res = cryptodev_hash_final(&ask->hash, digest);
- if (res < 0)
- return res;
-
- res = memcpy_toiovec(m->msg_iov, digest, ask->hash.digestsize);
- if (res != 0)
- return res;
-
- res = cryptodev_hash_reset(&ask->hash);
- if (res != 0)
- return res;
+ // FIXME: locking
+ if (ask->addr.salg_type[0] == 0)
+ return -EINVAL; // FIXME: better error code for "not bound"?
- return ask->hash.digestsize;
+ return do_data_recvmsg(iocb, ask, 0, m, total_len, flags);
}
static const struct proto_ops alg_proto_ops = {
@@ -439,15 +615,21 @@ static int __init init_cryptodev(void)
if (unlikely(rc != 0))
goto err;
+ rc = proto_register(&data_proto, 1);
+ if (unlikely(rc != 0))
+ goto err_alg_proto;
+
rc = sock_register(&alg_pf);
if (unlikely(rc != 0))
- goto err_proto;
+ goto err_data_proto;
printk(KERN_INFO PFX "driver loaded.\n");
return 0;
-err_proto:
+err_data_proto:
+ proto_unregister(&data_proto);
+err_alg_proto:
proto_unregister(&alg_proto);
err:
printk(KERN_ERR PFX "driver registration failed\n");
@@ -457,6 +639,7 @@ err:
static void __exit exit_cryptodev(void)
{
sock_unregister(PF_ALG);
+ proto_unregister(&data_proto);
proto_unregister(&alg_proto);
printk(KERN_INFO PFX "driver unloaded.\n");
diff --git a/tests/ncr.c b/tests/ncr.c
index 207a233..362cdaa 100644
--- a/tests/ncr.c
+++ b/tests/ncr.c
@@ -96,14 +96,17 @@ struct hash_vectors_st {
#define HASH_DATA_SIZE 64
/* SHA1 and other hashes */
-static int test_ncr_hash()
+static int test_ncr_hash(int with_accept)
{
uint8_t data[HASH_DATA_SIZE];
int i, j;
ssize_t data_size;
/* convert it to key */
- fprintf(stdout, "Tests on Hashes\n");
+ if (with_accept)
+ fprintf(stdout, "Tests on Hashes with accept()\n");
+ else
+ fprintf(stdout, "Tests on Hashes without accept()\n");
for (i = 0; i < sizeof(hash_vectors) / sizeof(hash_vectors[0]); i++) {
struct sockaddr_alg salg;
int fd, hfd;
@@ -124,18 +127,21 @@ static int test_ncr_hash()
return 1;
}
- if (listen(fd, 1) != 0) {
- perror("listen()");
- return 1;
- }
+ if (with_accept) {
+ if (listen(fd, 1) != 0) {
+ perror("listen()");
+ return 1;
+ }
- hfd = accept(fd, NULL, NULL);
- if (hfd < 0) {
- perror("accept()");
- return 1;
- }
+ hfd = accept(fd, NULL, NULL);
+ if (hfd < 0) {
+ perror("accept()");
+ return 1;
+ }
- close(fd);
+ close(fd);
+ } else
+ hfd = fd;
errno = 0;
if (write(hfd, hash_vectors[i].plaintext,
@@ -182,7 +188,9 @@ static int test_ncr_hash()
int main()
{
- if (test_ncr_hash())
+ if (test_ncr_hash(0))
+ return 1;
+ if (test_ncr_hash(1))
return 1;
return 0;