diff options
Diffstat (limited to 'src')
24 files changed, 665 insertions, 367 deletions
diff --git a/src/responder/autofs/autofs_private.h b/src/responder/autofs/autofs_private.h index efe1d623f..6a39b17ad 100644 --- a/src/responder/autofs/autofs_private.h +++ b/src/responder/autofs/autofs_private.h @@ -33,6 +33,10 @@ struct autofs_ctx { hash_table_t *maps; }; +struct autofs_state_ctx { + char *automntmap_name; +}; + struct autofs_cmd_ctx { struct cli_ctx *cctx; char *mapname; @@ -74,6 +78,7 @@ struct autofs_map_ctx { }; struct sss_cmd_table *get_autofs_cmds(void); +int autofs_connection_setup(struct cli_ctx *cctx); void autofs_map_hash_delete_cb(hash_entry_t *item, hash_destroy_enum deltype, void *pvt); diff --git a/src/responder/autofs/autofssrv.c b/src/responder/autofs/autofssrv.c index 9647c3473..c72f3c1f7 100644 --- a/src/responder/autofs/autofssrv.c +++ b/src/responder/autofs/autofssrv.c @@ -131,6 +131,7 @@ autofs_process_init(TALLOC_CTX *mem_ctx, &monitor_autofs_methods, "autofs", &autofs_dp_methods.vtable, + autofs_connection_setup, &rctx); if (ret != EOK) { DEBUG(SSSDBG_FATAL_FAILURE, "sss_process_init() failed\n"); diff --git a/src/responder/autofs/autofssrv_cmd.c b/src/responder/autofs/autofssrv_cmd.c index 42754aceb..9666ab2d1 100644 --- a/src/responder/autofs/autofssrv_cmd.c +++ b/src/responder/autofs/autofssrv_cmd.c @@ -240,6 +240,7 @@ static int sss_autofs_cmd_setautomntent(struct cli_ctx *client) { struct autofs_cmd_ctx *cmdctx; + struct cli_protocol *pctx; uint8_t *body; size_t blen; errno_t ret = EOK; @@ -254,7 +255,9 @@ sss_autofs_cmd_setautomntent(struct cli_ctx *client) } cmdctx->cctx = client; - sss_packet_get_body(client->creq->in, &body, &blen); + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + + sss_packet_get_body(pctx->creq->in, &body, &blen); /* if not terminated fail */ if (body[blen -1] != '\0') { @@ -290,9 +293,9 @@ static void sss_autofs_cmd_setautomntent_done(struct tevent_req *req) { struct autofs_cmd_ctx *cmdctx = tevent_req_callback_data(req, struct autofs_cmd_ctx); + struct cli_protocol *pctx; errno_t ret; errno_t reqret; - struct sss_packet *packet; uint8_t *body; size_t blen; @@ -306,26 +309,27 @@ static void sss_autofs_cmd_setautomntent_done(struct tevent_req *req) return; } + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* Either we succeeded or no domains were eligible */ - ret = sss_packet_new(cmdctx->cctx->creq, 0, - sss_packet_get_cmd(cmdctx->cctx->creq->in), - &cmdctx->cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret == EOK) { if (reqret == ENOENT) { DEBUG(SSSDBG_TRACE_FUNC, "setautomntent did not find requested map\n"); /* Notify the caller that this entry wasn't found */ - sss_cmd_empty_packet(cmdctx->cctx->creq->out); + sss_cmd_empty_packet(pctx->creq->out); } else { DEBUG(SSSDBG_TRACE_FUNC, "setautomntent found data\n"); - packet = cmdctx->cctx->creq->out; - ret = sss_packet_grow(packet, 2*sizeof(uint32_t)); + ret = sss_packet_grow(pctx->creq->out, 2*sizeof(uint32_t)); if (ret != EOK) { DEBUG(SSSDBG_CRIT_FAILURE, "Couldn't grow the packet\n"); talloc_free(cmdctx); return; } - sss_packet_get_body(packet, &body, &blen); + sss_packet_get_body(pctx->creq->out, &body, &blen); /* Got some results */ SAFEALIGN_SETMEM_UINT32(body, 1, NULL); @@ -417,10 +421,13 @@ setautomntent_send(TALLOC_CTX *mem_ctx, struct setautomntent_state *state; struct cli_ctx *client = cmdctx->cctx; struct autofs_dom_ctx *dctx; - struct autofs_ctx *actx = - talloc_get_type(client->rctx->pvt_ctx, struct autofs_ctx); + struct autofs_ctx *actx; + struct autofs_state_ctx *state_ctx; struct setautomntent_lookup_ctx *lookup_ctx; + actx = talloc_get_type(client->rctx->pvt_ctx, struct autofs_ctx); + state_ctx = talloc_get_type(client->state_ctx, struct autofs_state_ctx); + req = tevent_req_create(mem_ctx, &state, struct setautomntent_state); if (!req) { DEBUG(SSSDBG_FATAL_FAILURE, @@ -458,8 +465,8 @@ setautomntent_send(TALLOC_CTX *mem_ctx, goto fail; } - client->automntmap_name = talloc_strdup(client, rawname); - if (!client->automntmap_name) { + state_ctx->automntmap_name = talloc_strdup(client, rawname); + if (!state_ctx->automntmap_name) { ret = ENOMEM; goto fail; } @@ -468,8 +475,8 @@ setautomntent_send(TALLOC_CTX *mem_ctx, dctx->domain = client->rctx->domains; cmdctx->check_next = true; - client->automntmap_name = talloc_strdup(client, state->mapname); - if (!client->automntmap_name) { + state_ctx->automntmap_name = talloc_strdup(client, state->mapname); + if (!state_ctx->automntmap_name) { ret = ENOMEM; goto fail; } @@ -909,6 +916,7 @@ sss_autofs_cmd_getautomntent(struct cli_ctx *client) struct autofs_cmd_ctx *cmdctx; struct autofs_map_ctx *map; struct autofs_ctx *actx; + struct cli_protocol *pctx; uint8_t *body; size_t blen; errno_t ret; @@ -930,8 +938,10 @@ sss_autofs_cmd_getautomntent(struct cli_ctx *client) return EIO; } + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* get autofs map name and index to query */ - sss_packet_get_body(client->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); SAFEALIGN_COPY_UINT32_CHECK(&namelen, body+c, blen, &c); @@ -1054,7 +1064,7 @@ getautomntent_process(struct autofs_cmd_ctx *cmdctx, struct autofs_map_ctx *map, uint32_t cursor, uint32_t max_entries) { - struct cli_ctx *client = cmdctx->cctx; + struct cli_protocol *pctx; errno_t ret; struct ldb_message *entry; size_t rp; @@ -1062,10 +1072,12 @@ getautomntent_process(struct autofs_cmd_ctx *cmdctx, uint8_t *body; size_t blen; + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* create response packet */ - ret = sss_packet_new(client->creq, 0, - sss_packet_get_cmd(client->creq->in), - &client->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } @@ -1073,7 +1085,7 @@ getautomntent_process(struct autofs_cmd_ctx *cmdctx, if (!map->map || !map->entries || !map->entries[0] || cursor >= map->entry_count) { DEBUG(SSSDBG_MINOR_FAILURE, "No entries found\n"); - ret = sss_cmd_empty_packet(client->creq->out); + ret = sss_cmd_empty_packet(pctx->creq->out); if (ret != EOK) { return autofs_cmd_done(cmdctx, ret); } @@ -1081,7 +1093,7 @@ getautomntent_process(struct autofs_cmd_ctx *cmdctx, } /* allocate memory for number of entries in the packet */ - ret = sss_packet_grow(client->creq->out, sizeof(uint32_t)); + ret = sss_packet_grow(pctx->creq->out, sizeof(uint32_t)); if (ret != EOK) { DEBUG(SSSDBG_OP_FAILURE, "Cannot grow packet\n"); goto done; @@ -1097,7 +1109,7 @@ getautomntent_process(struct autofs_cmd_ctx *cmdctx, entry = map->entries[cursor]; cursor++; - ret = fill_autofs_entry(entry, client->creq->out, &rp); + ret = fill_autofs_entry(entry, pctx->creq->out, &rp); if (ret != EOK) { DEBUG(SSSDBG_MINOR_FAILURE, "Cannot fill entry %d/%d, skipping\n", i, stop); @@ -1108,15 +1120,15 @@ getautomntent_process(struct autofs_cmd_ctx *cmdctx, /* packet grows in fill_autofs_entry, body pointer may change, * thus we have to obtain it here */ - sss_packet_get_body(client->creq->out, &body, &blen); + sss_packet_get_body(pctx->creq->out, &body, &blen); rp = 0; SAFEALIGN_SET_UINT32(&body[rp], nentries, &rp); ret = EOK; done: - sss_packet_set_error(client->creq->out, ret); - sss_cmd_done(client, cmdctx); + sss_packet_set_error(pctx->creq->out, ret); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } @@ -1187,6 +1199,7 @@ sss_autofs_cmd_getautomntbyname(struct cli_ctx *client) struct autofs_cmd_ctx *cmdctx; struct autofs_map_ctx *map; struct autofs_ctx *actx; + struct cli_protocol *pctx; uint8_t *body; size_t blen; uint32_t namelen; @@ -1208,8 +1221,10 @@ sss_autofs_cmd_getautomntbyname(struct cli_ctx *client) return EIO; } + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* get autofs map name and index to query */ - sss_packet_get_body(client->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); /* FIXME - split out a function to get string from <len><str>\0 */ SAFEALIGN_COPY_UINT32_CHECK(&namelen, body+c, blen, &c); @@ -1352,7 +1367,7 @@ getautomntbyname_process(struct autofs_cmd_ctx *cmdctx, struct autofs_map_ctx *map, const char *key) { - struct cli_ctx *client = cmdctx->cctx; + struct cli_protocol *pctx; errno_t ret; size_t i; const char *k; @@ -1362,17 +1377,19 @@ getautomntbyname_process(struct autofs_cmd_ctx *cmdctx, uint8_t *body; size_t blen, rp; + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* create response packet */ - ret = sss_packet_new(client->creq, 0, - sss_packet_get_cmd(client->creq->in), - &client->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } if (!map->map || !map->entries || !map->entries[0]) { DEBUG(SSSDBG_MINOR_FAILURE, "No entries found\n"); - ret = sss_cmd_empty_packet(client->creq->out); + ret = sss_cmd_empty_packet(pctx->creq->out); if (ret != EOK) { return autofs_cmd_done(cmdctx, ret); } @@ -1395,7 +1412,7 @@ getautomntbyname_process(struct autofs_cmd_ctx *cmdctx, if (i >= map->entry_count) { DEBUG(SSSDBG_MINOR_FAILURE, "No key named [%s] found\n", key); - ret = sss_cmd_empty_packet(client->creq->out); + ret = sss_cmd_empty_packet(pctx->creq->out); if (ret != EOK) { return autofs_cmd_done(cmdctx, ret); } @@ -1408,12 +1425,12 @@ getautomntbyname_process(struct autofs_cmd_ctx *cmdctx, valuelen = 1 + strlen(value); len = sizeof(uint32_t) + sizeof(uint32_t) + valuelen; - ret = sss_packet_grow(client->creq->out, len); + ret = sss_packet_grow(pctx->creq->out, len); if (ret != EOK) { goto done; } - sss_packet_get_body(client->creq->out, &body, &blen); + sss_packet_get_body(pctx->creq->out, &body, &blen); rp = 0; SAFEALIGN_SET_UINT32(&body[rp], len, &rp); @@ -1428,8 +1445,8 @@ getautomntbyname_process(struct autofs_cmd_ctx *cmdctx, ret = EOK; done: - sss_packet_set_error(client->creq->out, ret); - sss_cmd_done(client, cmdctx); + sss_packet_set_error(pctx->creq->out, ret); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } @@ -1437,14 +1454,17 @@ done: static int sss_autofs_cmd_endautomntent(struct cli_ctx *client) { + struct cli_protocol *pctx; errno_t ret; DEBUG(SSSDBG_TRACE_FUNC, "endautomntent called\n"); + pctx = talloc_get_type(client->protocol_ctx, struct cli_protocol); + /* create response packet */ - ret = sss_packet_new(client->creq, 0, - sss_packet_get_cmd(client->creq->in), - &client->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; @@ -1476,3 +1496,18 @@ struct sss_cmd_table *get_autofs_cmds(void) return autofs_cmds; } + +int autofs_connection_setup(struct cli_ctx *cctx) +{ + int ret; + + ret = sss_connection_setup(cctx); + if (ret != EOK) return ret; + + cctx->state_ctx = talloc_zero(cctx, struct autofs_state_ctx); + if (!cctx->state_ctx) { + return ENOMEM; + } + + return EOK; +} diff --git a/src/responder/common/responder.h b/src/responder/common/responder.h index 2b5b05412..adf7b3d19 100644 --- a/src/responder/common/responder.h +++ b/src/responder/common/responder.h @@ -68,6 +68,11 @@ struct cli_protocol_version { const char *description; }; +struct cli_protocol { + struct cli_request *creq; + struct cli_protocol_version *cli_protocol_version; +}; + struct resp_ctx; struct be_conn { @@ -130,26 +135,14 @@ struct cli_ctx { struct resp_ctx *rctx; int cfd; struct tevent_fd *cfde; + tevent_fd_handler_t cfd_handler; struct sockaddr_un addr; - struct cli_request *creq; - struct cli_protocol_version *cli_protocol_version; int priv; struct cli_creds *creds; - int pwent_dom_idx; - int pwent_cur; - - int grent_dom_idx; - int grent_cur; - - int svc_dom_idx; - int svcent_cur; - - char *netgr_name; - int netgrent_cur; - - char *automntmap_name; + void *protocol_ctx; + void *state_ctx; struct tevent_timer *idle; }; @@ -165,6 +158,12 @@ struct mon_cli_iface; /* * responder_common.c * + */ + +typedef int (*connection_setup_t)(struct cli_ctx *cctx); + +int sss_connection_setup(struct cli_ctx *cctx); +/* * NOTE: We would like to use more strong typing for the @dp_vtable argument * but can't since it accepts either a struct data_provider_iface * or struct data_provider_rev_iface. So pass the base struct: sbus_vtable @@ -183,6 +182,7 @@ int sss_process_init(TALLOC_CTX *mem_ctx, struct mon_cli_iface *monitor_intf, const char *cli_name, struct sbus_vtable *dp_intf, + connection_setup_t conn_setup, struct resp_ctx **responder_ctx); int sss_dp_get_domain_conn(struct resp_ctx *rctx, const char *domain, diff --git a/src/responder/common/responder_cmd.c b/src/responder/common/responder_cmd.c index 1ac86fddf..175a8e5d6 100644 --- a/src/responder/common/responder_cmd.c +++ b/src/responder/common/responder_cmd.c @@ -24,20 +24,25 @@ #include "responder/common/responder.h" #include "responder/common/responder_packet.h" + int sss_cmd_send_error(struct cli_ctx *cctx, int err) { + struct cli_protocol *pctx; int ret; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + if (!pctx) return EINVAL; + /* create response packet */ - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { DEBUG(SSSDBG_CRIT_FAILURE, "Cannot create new packet: %d\n", ret); return ret; } - sss_packet_set_error(cctx->creq->out, err); + sss_packet_set_error(pctx->creq->out, err); return EOK; } @@ -63,22 +68,26 @@ int sss_cmd_empty_packet(struct sss_packet *packet) int sss_cmd_send_empty(struct cli_ctx *cctx, TALLOC_CTX *freectx) { + struct cli_protocol *pctx; int ret; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + if (!pctx) return EINVAL; + /* create response packet */ - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } - ret = sss_cmd_empty_packet(cctx->creq->out); + ret = sss_cmd_empty_packet(pctx->creq->out); if (ret != EOK) { return ret; } - sss_packet_set_error(cctx->creq->out, EOK); + sss_packet_set_error(pctx->creq->out, EOK); sss_cmd_done(cctx, freectx); return EOK; } @@ -95,6 +104,7 @@ void sss_cmd_done(struct cli_ctx *cctx, void *freectx) int sss_cmd_get_version(struct cli_ctx *cctx) { + struct cli_protocol *pctx; uint8_t *req_body; size_t req_blen; uint8_t *body; @@ -105,16 +115,19 @@ int sss_cmd_get_version(struct cli_ctx *cctx) int i; static struct cli_protocol_version *cli_protocol_version = NULL; - cctx->cli_protocol_version = NULL; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + if (!pctx) return EINVAL; + + pctx->cli_protocol_version = NULL; if (cli_protocol_version == NULL) { cli_protocol_version = register_cli_protocol_version(); } if (cli_protocol_version != NULL) { - cctx->cli_protocol_version = &cli_protocol_version[0]; + pctx->cli_protocol_version = &cli_protocol_version[0]; - sss_packet_get_body(cctx->creq->in, &req_body, &req_blen); + sss_packet_get_body(pctx->creq->in, &req_body, &req_blen); if (req_blen == sizeof(uint32_t)) { memcpy(&client_version, req_body, sizeof(uint32_t)); DEBUG(SSSDBG_FUNC_DATA, @@ -123,7 +136,7 @@ int sss_cmd_get_version(struct cli_ctx *cctx) i=0; while(cli_protocol_version[i].version>0) { if (cli_protocol_version[i].version == client_version) { - cctx->cli_protocol_version = &cli_protocol_version[i]; + pctx->cli_protocol_version = &cli_protocol_version[i]; break; } i++; @@ -132,16 +145,16 @@ int sss_cmd_get_version(struct cli_ctx *cctx) } /* create response packet */ - ret = sss_packet_new(cctx->creq, sizeof(uint32_t), - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, sizeof(uint32_t), + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } - sss_packet_get_body(cctx->creq->out, &body, &blen); + sss_packet_get_body(pctx->creq->out, &body, &blen); - protocol_version = (cctx->cli_protocol_version != NULL) - ? cctx->cli_protocol_version->version : 0; + protocol_version = (pctx->cli_protocol_version != NULL) + ? pctx->cli_protocol_version->version : 0; SAFEALIGN_COPY_UINT32(body, &protocol_version, NULL); DEBUG(SSSDBG_FUNC_DATA, "Offered version [%d].\n", protocol_version); diff --git a/src/responder/common/responder_common.c b/src/responder/common/responder_common.c index 67a7aa8f9..e67456bb1 100644 --- a/src/responder/common/responder_common.c +++ b/src/responder/common/responder_common.c @@ -241,12 +241,14 @@ done: return ret; } - static void client_send(struct cli_ctx *cctx) { + struct cli_protocol *pctx; int ret; - ret = sss_packet_send(cctx->creq->out, cctx->cfd); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + + ret = sss_packet_send(pctx->creq->out, cctx->cfd); if (ret == EAGAIN) { /* not all data was sent, loop again */ return; @@ -260,26 +262,30 @@ static void client_send(struct cli_ctx *cctx) /* ok all sent */ TEVENT_FD_NOT_WRITEABLE(cctx->cfde); TEVENT_FD_READABLE(cctx->cfde); - talloc_free(cctx->creq); - cctx->creq = NULL; + talloc_zfree(pctx->creq); return; } static int client_cmd_execute(struct cli_ctx *cctx, struct sss_cmd_table *sss_cmds) { + struct cli_protocol *pctx; enum sss_cli_command cmd; - cmd = sss_packet_get_cmd(cctx->creq->in); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + cmd = sss_packet_get_cmd(pctx->creq->in); return sss_cmd_execute(cctx, cmd, sss_cmds); } static void client_recv(struct cli_ctx *cctx) { + struct cli_protocol *pctx; int ret; - if (!cctx->creq) { - cctx->creq = talloc_zero(cctx, struct cli_request); - if (!cctx->creq) { + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + + if (!pctx->creq) { + pctx->creq = talloc_zero(cctx, struct cli_request); + if (!pctx->creq) { DEBUG(SSSDBG_FATAL_FAILURE, "Failed to alloc request, aborting client!\n"); talloc_free(cctx); @@ -287,9 +293,9 @@ static void client_recv(struct cli_ctx *cctx) } } - if (!cctx->creq->in) { - ret = sss_packet_new(cctx->creq, SSS_PACKET_MAX_RECV_SIZE, - 0, &cctx->creq->in); + if (!pctx->creq->in) { + ret = sss_packet_new(pctx->creq, SSS_PACKET_MAX_RECV_SIZE, + 0, &pctx->creq->in); if (ret != EOK) { DEBUG(SSSDBG_FATAL_FAILURE, "Failed to alloc request, aborting client!\n"); @@ -298,7 +304,7 @@ static void client_recv(struct cli_ctx *cctx) } } - ret = sss_packet_recv(cctx->creq->in, cctx->cfd); + ret = sss_packet_recv(pctx->creq->in, cctx->cfd); switch (ret) { case EOK: /* do not read anymore */ @@ -368,6 +374,7 @@ static void client_fd_handler(struct tevent_context *ev, struct accept_fd_ctx { struct resp_ctx *rctx; bool is_private; + connection_setup_t connection_setup; }; static void idle_handler(struct tevent_context *ev, @@ -468,8 +475,19 @@ static void accept_fd_handler(struct tevent_context *ev, } } + ret = accept_ctx->connection_setup(cctx); + if (ret != EOK) { + close(cctx->cfd); + talloc_free(cctx); + DEBUG(SSSDBG_OP_FAILURE, + "Failed to setup client handler%s\n", + accept_ctx->is_private ? " on privileged pipe" : ""); + return; + } + cctx->cfde = tevent_add_fd(ev, cctx, cctx->cfd, - TEVENT_FD_READ, client_fd_handler, cctx); + TEVENT_FD_READ, cctx->cfd_handler, + cctx); if (!cctx->cfde) { close(cctx->cfd); talloc_free(cctx); @@ -644,10 +662,11 @@ done: } /* create a unix socket and listen to it */ -static int set_unix_socket(struct resp_ctx *rctx) +static int set_unix_socket(struct resp_ctx *rctx, + connection_setup_t conn_setup) { errno_t ret; - struct accept_fd_ctx *accept_ctx; + struct accept_fd_ctx *accept_ctx = NULL; /* for future use */ #if 0 @@ -699,6 +718,7 @@ static int set_unix_socket(struct resp_ctx *rctx) if(!accept_ctx) goto failed; accept_ctx->rctx = rctx; accept_ctx->is_private = false; + accept_ctx->connection_setup = conn_setup; rctx->lfde = tevent_add_fd(rctx->ev, rctx, rctx->lfd, TEVENT_FD_READ, accept_fd_handler, @@ -723,6 +743,7 @@ static int set_unix_socket(struct resp_ctx *rctx) if(!accept_ctx) goto failed; accept_ctx->rctx = rctx; accept_ctx->is_private = true; + accept_ctx->connection_setup = conn_setup; rctx->priv_lfde = tevent_add_fd(rctx->ev, rctx, rctx->priv_lfd, TEVENT_FD_READ, accept_fd_handler, @@ -742,6 +763,18 @@ failed: return EIO; } +int sss_connection_setup(struct cli_ctx *cctx) +{ + cctx->protocol_ctx = talloc_zero(cctx, struct cli_protocol); + if (!cctx->protocol_ctx) { + return ENOMEM; + } + + cctx->cfd_handler = client_fd_handler; + + return EOK; +} + static int sss_responder_ctx_destructor(void *ptr) { struct resp_ctx *rctx = talloc_get_type(ptr, struct resp_ctx); @@ -829,6 +862,7 @@ int sss_process_init(TALLOC_CTX *mem_ctx, struct mon_cli_iface *monitor_intf, const char *cli_name, struct sbus_vtable *dp_intf, + connection_setup_t conn_setup, struct resp_ctx **responder_ctx) { struct resp_ctx *rctx; @@ -959,7 +993,7 @@ int sss_process_init(TALLOC_CTX *mem_ctx, } /* after all initializations we are ready to listen on our socket */ - ret = set_unix_socket(rctx); + ret = set_unix_socket(rctx, conn_setup); if (ret != EOK) { DEBUG(SSSDBG_FATAL_FAILURE, "fatal error initializing socket\n"); goto fail; diff --git a/src/responder/ifp/ifpsrv.c b/src/responder/ifp/ifpsrv.c index fdb11f650..a2137ecb2 100644 --- a/src/responder/ifp/ifpsrv.c +++ b/src/responder/ifp/ifpsrv.c @@ -239,6 +239,7 @@ int ifp_process_init(TALLOC_CTX *mem_ctx, &monitor_ifp_methods, "InfoPipe", &ifp_dp_methods.vtable, + sss_connection_setup, &rctx); if (ret != EOK) { DEBUG(SSSDBG_FATAL_FAILURE, "sss_process_init() failed\n"); diff --git a/src/responder/nss/nsssrv.c b/src/responder/nss/nsssrv.c index d01884789..8be3455e5 100644 --- a/src/responder/nss/nsssrv.c +++ b/src/responder/nss/nsssrv.c @@ -420,6 +420,7 @@ int nss_process_init(TALLOC_CTX *mem_ctx, NSS_SBUS_SERVICE_VERSION, &monitor_nss_methods, "NSS", &nss_dp_methods.vtable, + nss_connection_setup, &rctx); if (ret != EOK) { DEBUG(SSSDBG_FATAL_FAILURE, "sss_process_init() failed\n"); diff --git a/src/responder/nss/nsssrv_cmd.c b/src/responder/nss/nsssrv_cmd.c index 535ba933e..9ba81a6aa 100644 --- a/src/responder/nss/nsssrv_cmd.c +++ b/src/responder/nss/nsssrv_cmd.c @@ -547,28 +547,30 @@ static int nss_cmd_getpw_send_reply(struct nss_dom_ctx *dctx, bool filter) { struct nss_cmd_ctx *cmdctx = dctx->cmdctx; struct cli_ctx *cctx = cmdctx->cctx; + struct cli_protocol *pctx; struct nss_ctx *nctx; int ret; int i; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return EFAULT; } i = dctx->res->count; - ret = fill_pwent(cctx->creq->out, + ret = fill_pwent(pctx->creq->out, dctx->domain, nctx, filter, true, dctx->res->msgs, &i); if (ret) { return ret; } - sss_packet_set_error(cctx->creq->out, EOK); + sss_packet_set_error(pctx->creq->out, EOK); sss_cmd_done(cctx, cmdctx); return EOK; } @@ -1403,7 +1405,7 @@ static int nss_check_name_of_well_known_sid(struct nss_cmd_ctx *cmdctx, struct sized_string sid; uint8_t *body; size_t blen; - struct cli_ctx *cctx; + struct cli_protocol *pctx; struct nss_ctx *nss_ctx; size_t pctr = 0; @@ -1434,22 +1436,22 @@ static int nss_check_name_of_well_known_sid(struct nss_cmd_ctx *cmdctx, to_sized_string(&sid, wk_sid); - cctx = cmdctx->cctx; - ret = sss_packet_new(cctx->creq, sid.len + 3 * sizeof(uint32_t), - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + ret = sss_packet_new(pctx->creq, sid.len + 3 * sizeof(uint32_t), + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ENOMEM; } - sss_packet_get_body(cctx->creq->out, &body, &blen); + sss_packet_get_body(pctx->creq->out, &body, &blen); SAFEALIGN_SETMEM_UINT32(body, 1, &pctr); /* num results */ SAFEALIGN_SETMEM_UINT32(body + pctr, 0, &pctr); /* reserved */ SAFEALIGN_SETMEM_UINT32(body + pctr, SSS_ID_TYPE_GID, &pctr); memcpy(&body[pctr], sid.str, sid.len); - sss_packet_set_error(cctx->creq->out, EOK); - sss_cmd_done(cctx, cmdctx); + sss_packet_set_error(pctx->creq->out, EOK); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } @@ -1464,6 +1466,7 @@ static int nss_cmd_getbynam(enum sss_cli_command cmd, struct cli_ctx *cctx) { struct tevent_req *req; + struct cli_protocol *pctx; struct nss_cmd_ctx *cmdctx; struct nss_dom_ctx *dctx; const char *rawname; @@ -1499,8 +1502,10 @@ static int nss_cmd_getbynam(enum sss_cli_command cmd, struct cli_ctx *cctx) } dctx->cmdctx = cmdctx; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + /* get user name to query */ - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); /* if not terminated fail */ if (body[blen -1] != '\0') { @@ -1894,6 +1899,7 @@ static int nss_cmd_getpwuid(struct cli_ctx *cctx) static int nss_cmd_getbyid(enum sss_cli_command cmd, struct cli_ctx *cctx) { + struct cli_protocol *pctx; struct nss_cmd_ctx *cmdctx; struct nss_dom_ctx *dctx; struct nss_ctx *nctx; @@ -1929,8 +1935,10 @@ static int nss_cmd_getbyid(enum sss_cli_command cmd, struct cli_ctx *cctx) } dctx->cmdctx = cmdctx; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + /* get id to query */ - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); if (blen != sizeof(uint32_t)) { ret = EINVAL; @@ -2151,6 +2159,7 @@ struct tevent_req *nss_cmd_setpwent_send(TALLOC_CTX *mem_ctx, { errno_t ret; struct nss_ctx *nctx; + struct nss_state_ctx *state_ctx; struct tevent_req *req; struct setent_ctx *state; struct sss_domain_info *dom; @@ -2158,10 +2167,11 @@ struct tevent_req *nss_cmd_setpwent_send(TALLOC_CTX *mem_ctx, DEBUG(SSSDBG_CONF_SETTINGS, "Received setpwent request\n"); nctx = talloc_get_type(client->rctx->pvt_ctx, struct nss_ctx); + state_ctx = talloc_get_type(client->state_ctx, struct nss_state_ctx); /* Reset the read pointers */ - client->pwent_dom_idx = 0; - client->pwent_cur = 0; + state_ctx->pwent.dom_idx = 0; + state_ctx->pwent.cur = 0; req = tevent_req_create(mem_ctx, &state, struct setent_ctx); if (!req) { @@ -2456,17 +2466,19 @@ static errno_t nss_cmd_setpwent_recv(struct tevent_req *req) static void nss_cmd_setpwent_done(struct tevent_req *req) { + struct cli_protocol *pctx; errno_t ret; struct nss_cmd_ctx *cmdctx = tevent_req_callback_data(req, struct nss_cmd_ctx); + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); ret = nss_cmd_setpwent_recv(req); talloc_zfree(req); if (ret == EOK || ret == ENOENT) { /* Either we succeeded or no domains were eligible */ - ret = sss_packet_new(cmdctx->cctx->creq, 0, - sss_packet_get_cmd(cmdctx->cctx->creq->in), - &cmdctx->cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret == EOK) { sss_cmd_done(cmdctx->cctx, cmdctx); return; @@ -2481,6 +2493,7 @@ static void nss_cmd_implicit_setpwent_done(struct tevent_req *req); static int nss_cmd_getpwent(struct cli_ctx *cctx) { struct nss_ctx *nctx; + struct nss_state_ctx *state_ctx; struct nss_cmd_ctx *cmdctx; struct tevent_req *req; @@ -2491,14 +2504,15 @@ static int nss_cmd_getpwent(struct cli_ctx *cctx) return ENOMEM; } cmdctx->cctx = cctx; + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); /* Save the current index and cursor locations * If we end up calling setpwent implicitly, because the response object * expired and has to be recreated, we want to resume from the same * location. */ - cmdctx->saved_dom_idx = cctx->pwent_dom_idx; - cmdctx->saved_cur = cctx->pwent_cur; + cmdctx->saved_dom_idx = state_ctx->pwent.dom_idx; + cmdctx->saved_cur = state_ctx->pwent.cur; nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); if(!nctx->pctx || !nctx->pctx->ready) { @@ -2519,85 +2533,92 @@ static int nss_cmd_getpwent(struct cli_ctx *cctx) static int nss_cmd_retpwent(struct cli_ctx *cctx, int num); static int nss_cmd_getpwent_immediate(struct nss_cmd_ctx *cmdctx) { - struct cli_ctx *cctx = cmdctx->cctx; + struct cli_protocol *pctx; uint8_t *body; size_t blen; uint32_t num; int ret; + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* get max num of entries to return in one call */ - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); if (blen != sizeof(uint32_t)) { return EINVAL; } SAFEALIGN_COPY_UINT32(&num, body, NULL); /* create response packet */ - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } - ret = nss_cmd_retpwent(cctx, num); + ret = nss_cmd_retpwent(cmdctx->cctx, num); - sss_packet_set_error(cctx->creq->out, ret); - sss_cmd_done(cctx, cmdctx); + sss_packet_set_error(pctx->creq->out, ret); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } static int nss_cmd_retpwent(struct cli_ctx *cctx, int num) { + struct cli_protocol *pctx; + struct nss_state_ctx *state_ctx; struct nss_ctx *nctx; - struct getent_ctx *pctx; + struct getent_ctx *gctx; struct ldb_message **msgs = NULL; struct dom_ctx *pdom = NULL; int n = 0; int ret = ENOENT; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); if (!nctx->pctx) goto none; - pctx = nctx->pctx; + gctx = nctx->pctx; while (ret == ENOENT) { - if (cctx->pwent_dom_idx >= pctx->num) break; + if (state_ctx->pwent.dom_idx >= gctx->num) break; - pdom = &pctx->doms[cctx->pwent_dom_idx]; + pdom = &gctx->doms[state_ctx->pwent.dom_idx]; - n = pdom->res->count - cctx->pwent_cur; - if (n <= 0 && (cctx->pwent_dom_idx+1 < pctx->num)) { - cctx->pwent_dom_idx++; - pdom = &pctx->doms[cctx->pwent_dom_idx]; + n = pdom->res->count - state_ctx->pwent.cur; + if (n <= 0 && (state_ctx->pwent.dom_idx+1 < gctx->num)) { + state_ctx->pwent.dom_idx++; + pdom = &gctx->doms[state_ctx->pwent.dom_idx]; n = pdom->res->count; - cctx->pwent_cur = 0; + state_ctx->pwent.cur = 0; } if (!n) break; if (n < 0) { - DEBUG(SSSDBG_CRIT_FAILURE, "BUG: Negative difference" - "[%d - %d = %d]\n", pdom->res->count, cctx->pwent_cur, n); + DEBUG(SSSDBG_CRIT_FAILURE, + "BUG: Negative difference[%d - %d = %d]\n", + pdom->res->count, state_ctx->pwent.cur, n); DEBUG(SSSDBG_CRIT_FAILURE, "Domain: %d (total %d)\n", - cctx->pwent_dom_idx, pctx->num); + state_ctx->pwent.dom_idx, gctx->num); break; } if (n > num) n = num; - msgs = &(pdom->res->msgs[cctx->pwent_cur]); + msgs = &(pdom->res->msgs[state_ctx->pwent.cur]); - ret = fill_pwent(cctx->creq->out, pdom->domain, nctx, + ret = fill_pwent(pctx->creq->out, pdom->domain, nctx, true, false, msgs, &n); - cctx->pwent_cur += n; + state_ctx->pwent.cur += n; } none: if (ret == ENOENT) { - ret = sss_cmd_empty_packet(cctx->creq->out); + ret = sss_cmd_empty_packet(pctx->creq->out); } return ret; } @@ -2607,7 +2628,7 @@ static void nss_cmd_implicit_setpwent_done(struct tevent_req *req) errno_t ret; struct nss_cmd_ctx *cmdctx = tevent_req_callback_data(req, struct nss_cmd_ctx); - + struct nss_state_ctx *state_ctx; ret = nss_cmd_setpwent_recv(req); talloc_zfree(req); @@ -2623,8 +2644,9 @@ static void nss_cmd_implicit_setpwent_done(struct tevent_req *req) } /* Restore the saved index and cursor locations */ - cmdctx->cctx->pwent_dom_idx = cmdctx->saved_dom_idx; - cmdctx->cctx->pwent_cur = cmdctx->saved_cur; + state_ctx = talloc_get_type(cmdctx->cctx->state_ctx, struct nss_state_ctx); + state_ctx->pwent.dom_idx = cmdctx->saved_dom_idx; + state_ctx->pwent.cur = cmdctx->saved_cur; ret = nss_cmd_getpwent_immediate(cmdctx); if (ret != EOK) { @@ -2637,17 +2659,21 @@ static void nss_cmd_implicit_setpwent_done(struct tevent_req *req) static int nss_cmd_endpwent(struct cli_ctx *cctx) { + struct cli_protocol *pctx; + struct nss_state_ctx *state_ctx; struct nss_ctx *nctx; int ret; DEBUG(SSSDBG_CONF_SETTINGS, "Terminating request info for all accounts\n"); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); /* create response packet */ - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; @@ -2655,8 +2681,8 @@ static int nss_cmd_endpwent(struct cli_ctx *cctx) if (nctx->pctx == NULL) goto done; /* Reset the indices so that subsequent requests start at zero */ - cctx->pwent_dom_idx = 0; - cctx->pwent_cur = 0; + state_ctx->pwent.dom_idx = 0; + state_ctx->pwent.cur = 0; done: sss_cmd_done(cctx, NULL); @@ -3134,29 +3160,30 @@ done: static int nss_cmd_getgr_send_reply(struct nss_dom_ctx *dctx, bool filter) { struct nss_cmd_ctx *cmdctx = dctx->cmdctx; - struct cli_ctx *cctx = cmdctx->cctx; + struct cli_protocol *pctx; struct nss_ctx *nctx; int ret; int i; - nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + nctx = talloc_get_type(cmdctx->cctx->rctx->pvt_ctx, struct nss_ctx); - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return EFAULT; } i = dctx->res->count; - ret = fill_grent(cctx->creq->out, + ret = fill_grent(pctx->creq->out, dctx->domain, nctx, filter, true, dctx->res->msgs, &i); if (ret) { return ret; } - sss_packet_set_error(cctx->creq->out, EOK); - sss_cmd_done(cctx, cmdctx); + sss_packet_set_error(pctx->creq->out, EOK); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } @@ -3511,6 +3538,7 @@ struct tevent_req *nss_cmd_setgrent_send(TALLOC_CTX *mem_ctx, { errno_t ret; struct nss_ctx *nctx; + struct nss_state_ctx *state_ctx; struct tevent_req *req; struct setent_ctx *state; struct sss_domain_info *dom; @@ -3518,10 +3546,11 @@ struct tevent_req *nss_cmd_setgrent_send(TALLOC_CTX *mem_ctx, DEBUG(SSSDBG_CONF_SETTINGS, "Received setgrent request\n"); nctx = talloc_get_type(client->rctx->pvt_ctx, struct nss_ctx); + state_ctx = talloc_get_type(client->state_ctx, struct nss_state_ctx); /* Reset the read pointers */ - client->grent_dom_idx = 0; - client->grent_cur = 0; + state_ctx->grent.dom_idx = 0; + state_ctx->grent.cur = 0; req = tevent_req_create(mem_ctx, &state, struct setent_ctx); if (!req) { @@ -3820,14 +3849,16 @@ static void nss_cmd_setgrent_done(struct tevent_req *req) errno_t ret; struct nss_cmd_ctx *cmdctx = tevent_req_callback_data(req, struct nss_cmd_ctx); + struct cli_protocol *pctx; + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); ret = nss_cmd_setgrent_recv(req); talloc_zfree(req); if (ret == EOK || ret == ENOENT) { /* Either we succeeded or no domains were eligible */ - ret = sss_packet_new(cmdctx->cctx->creq, 0, - sss_packet_get_cmd(cmdctx->cctx->creq->in), - &cmdctx->cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret == EOK) { sss_cmd_done(cmdctx->cctx, cmdctx); return; @@ -3840,6 +3871,8 @@ static void nss_cmd_setgrent_done(struct tevent_req *req) static int nss_cmd_retgrent(struct cli_ctx *cctx, int num) { + struct cli_protocol *pctx; + struct nss_state_ctx *state_ctx; struct nss_ctx *nctx; struct getent_ctx *gctx; struct ldb_message **msgs = NULL; @@ -3847,71 +3880,75 @@ static int nss_cmd_retgrent(struct cli_ctx *cctx, int num) int n = 0; int ret = ENOENT; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); if (!nctx->gctx) goto none; gctx = nctx->gctx; while (ret == ENOENT) { - if (cctx->grent_dom_idx >= gctx->num) break; + if (state_ctx->grent.dom_idx >= gctx->num) break; - gdom = &gctx->doms[cctx->grent_dom_idx]; + gdom = &gctx->doms[state_ctx->grent.dom_idx]; - n = gdom->res->count - cctx->grent_cur; - if (n <= 0 && (cctx->grent_dom_idx+1 < gctx->num)) { - cctx->grent_dom_idx++; - gdom = &gctx->doms[cctx->grent_dom_idx]; + n = gdom->res->count - state_ctx->grent.cur; + if (n <= 0 && (state_ctx->grent.dom_idx+1 < gctx->num)) { + state_ctx->grent.dom_idx++; + gdom = &gctx->doms[state_ctx->grent.dom_idx]; n = gdom->res->count; - cctx->grent_cur = 0; + state_ctx->grent.cur = 0; } if (!n) break; if (n > num) n = num; - msgs = &(gdom->res->msgs[cctx->grent_cur]); + msgs = &(gdom->res->msgs[state_ctx->grent.cur]); - ret = fill_grent(cctx->creq->out, + ret = fill_grent(pctx->creq->out, gdom->domain, nctx, true, false, msgs, &n); - cctx->grent_cur += n; + state_ctx->grent.cur += n; } none: if (ret == ENOENT) { - ret = sss_cmd_empty_packet(cctx->creq->out); + ret = sss_cmd_empty_packet(pctx->creq->out); } return ret; } static int nss_cmd_getgrent_immediate(struct nss_cmd_ctx *cmdctx) { - struct cli_ctx *cctx = cmdctx->cctx; + struct cli_protocol *pctx; uint8_t *body; size_t blen; uint32_t num; int ret; + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* get max num of entries to return in one call */ - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); if (blen != sizeof(uint32_t)) { return EINVAL; } SAFEALIGN_COPY_UINT32(&num, body, NULL); /* create response packet */ - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } - ret = nss_cmd_retgrent(cctx, num); + ret = nss_cmd_retgrent(cmdctx->cctx, num); - sss_packet_set_error(cctx->creq->out, ret); - sss_cmd_done(cctx, cmdctx); + sss_packet_set_error(pctx->creq->out, ret); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } @@ -3920,6 +3957,7 @@ static void nss_cmd_implicit_setgrent_done(struct tevent_req *req); static int nss_cmd_getgrent(struct cli_ctx *cctx) { struct nss_ctx *nctx; + struct nss_state_ctx *state_ctx; struct nss_cmd_ctx *cmdctx; struct tevent_req *req; @@ -3936,8 +3974,9 @@ static int nss_cmd_getgrent(struct cli_ctx *cctx) * expired and has to be recreated, we want to resume from the same * location. */ - cmdctx->saved_dom_idx = cctx->grent_dom_idx; - cmdctx->saved_cur = cctx->grent_cur; + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); + cmdctx->saved_dom_idx = state_ctx->grent.dom_idx; + cmdctx->saved_cur = state_ctx->grent.cur; nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); if(!nctx->gctx || !nctx->gctx->ready) { @@ -3960,6 +3999,7 @@ static void nss_cmd_implicit_setgrent_done(struct tevent_req *req) errno_t ret; struct nss_cmd_ctx *cmdctx = tevent_req_callback_data(req, struct nss_cmd_ctx); + struct nss_state_ctx *state_ctx; ret = nss_cmd_setgrent_recv(req); talloc_zfree(req); @@ -3976,8 +4016,9 @@ static void nss_cmd_implicit_setgrent_done(struct tevent_req *req) } /* Restore the saved index and cursor locations */ - cmdctx->cctx->grent_dom_idx = cmdctx->saved_dom_idx; - cmdctx->cctx->grent_cur = cmdctx->saved_cur; + state_ctx = talloc_get_type(cmdctx->cctx->state_ctx, struct nss_state_ctx); + state_ctx->grent.dom_idx = cmdctx->saved_dom_idx; + state_ctx->grent.cur = cmdctx->saved_cur; ret = nss_cmd_getgrent_immediate(cmdctx); if (ret != EOK) { @@ -3990,17 +4031,21 @@ static void nss_cmd_implicit_setgrent_done(struct tevent_req *req) static int nss_cmd_endgrent(struct cli_ctx *cctx) { + struct cli_protocol *pctx; + struct nss_state_ctx *state_ctx; struct nss_ctx *nctx; int ret; DEBUG(SSSDBG_CONF_SETTINGS, "Terminating request info for all groups\n"); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); /* create response packet */ - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; @@ -4008,8 +4053,8 @@ static int nss_cmd_endgrent(struct cli_ctx *cctx) if (nctx->gctx == NULL) goto done; /* Reset the indices so that subsequent requests start at zero */ - cctx->grent_dom_idx = 0; - cctx->grent_cur = 0; + state_ctx->grent.dom_idx = 0; + state_ctx->grent.cur = 0; done: sss_cmd_done(cctx, NULL); @@ -4255,26 +4300,27 @@ static int fill_initgr(struct sss_packet *packet, static int nss_cmd_initgr_send_reply(struct nss_dom_ctx *dctx) { struct nss_cmd_ctx *cmdctx = dctx->cmdctx; - struct cli_ctx *cctx = cmdctx->cctx; + struct cli_protocol *pctx; struct nss_ctx *nctx; int ret; - nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + nctx = talloc_get_type(cmdctx->cctx->rctx->pvt_ctx, struct nss_ctx); - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return EFAULT; } - ret = fill_initgr(cctx->creq->out, dctx->domain, dctx->res, nctx, + ret = fill_initgr(pctx->creq->out, dctx->domain, dctx->res, nctx, dctx->mc_name, cmdctx->normalized_name); if (ret) { return ret; } - sss_packet_set_error(cctx->creq->out, EOK); - sss_cmd_done(cctx, cmdctx); + sss_packet_set_error(pctx->creq->out, EOK); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } @@ -5267,6 +5313,7 @@ static errno_t nss_cmd_getbysid_send_reply(struct nss_dom_ctx *dctx) { struct nss_cmd_ctx *cmdctx = dctx->cmdctx; struct cli_ctx *cctx = cmdctx->cctx; + struct cli_protocol *pctx; int ret; enum sss_id_type id_type; @@ -5276,9 +5323,11 @@ static errno_t nss_cmd_getbysid_send_reply(struct nss_dom_ctx *dctx) return ENOENT; } - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return EFAULT; } @@ -5291,21 +5340,21 @@ static errno_t nss_cmd_getbysid_send_reply(struct nss_dom_ctx *dctx) switch(cmdctx->cmd) { case SSS_NSS_GETNAMEBYSID: - ret = fill_name(cctx->creq->out, + ret = fill_name(pctx->creq->out, dctx->domain, id_type, true, dctx->res->msgs[0]); break; case SSS_NSS_GETIDBYSID: - ret = fill_id(cctx->creq->out, id_type, dctx->res->msgs[0]); + ret = fill_id(pctx->creq->out, id_type, dctx->res->msgs[0]); break; case SSS_NSS_GETSIDBYNAME: case SSS_NSS_GETSIDBYID: - ret = fill_sid(cctx->creq->out, id_type, dctx->res->msgs[0]); + ret = fill_sid(pctx->creq->out, id_type, dctx->res->msgs[0]); break; case SSS_NSS_GETORIGBYNAME: - ret = fill_orig(cctx->creq->out, cctx->rctx, id_type, + ret = fill_orig(pctx->creq->out, cctx->rctx, id_type, dctx->res->msgs[0]); break; default: @@ -5316,7 +5365,7 @@ static errno_t nss_cmd_getbysid_send_reply(struct nss_dom_ctx *dctx) return ret; } - sss_packet_set_error(cctx->creq->out, EOK); + sss_packet_set_error(pctx->creq->out, EOK); sss_cmd_done(cctx, cmdctx); return EOK; } @@ -5330,10 +5379,12 @@ static int nss_check_well_known_sid(struct nss_cmd_ctx *cmdctx) struct sized_string name; uint8_t *body; size_t blen; - struct cli_ctx *cctx; + struct cli_protocol *pctx; struct nss_ctx *nss_ctx; size_t pctr = 0; + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + ret = well_known_sid_to_name(cmdctx->secid, &wk_dom_name, &wk_name); if (ret != EOK) { DEBUG(SSSDBG_TRACE_ALL, "SID [%s] is not a Well-Known SID.\n", @@ -5360,23 +5411,22 @@ static int nss_check_well_known_sid(struct nss_cmd_ctx *cmdctx) to_sized_string(&name, wk_name); } - cctx = cmdctx->cctx; - ret = sss_packet_new(cctx->creq, name.len + 3 * sizeof(uint32_t), - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, name.len + 3 * sizeof(uint32_t), + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { talloc_free(fq_name); return ENOMEM; } - sss_packet_get_body(cctx->creq->out, &body, &blen); + sss_packet_get_body(pctx->creq->out, &body, &blen); SAFEALIGN_SETMEM_UINT32(body, 1, &pctr); /* num results */ SAFEALIGN_SETMEM_UINT32(body + pctr, 0, &pctr); /* reserved */ SAFEALIGN_SETMEM_UINT32(body + pctr, SSS_ID_TYPE_GID, &pctr); memcpy(&body[pctr], name.str, name.len); - sss_packet_set_error(cctx->creq->out, EOK); - sss_cmd_done(cctx, cmdctx); + sss_packet_set_error(pctx->creq->out, EOK); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } @@ -5390,6 +5440,7 @@ static int nss_cmd_getbysid(enum sss_cli_command cmd, struct cli_ctx *cctx) uint8_t *body; size_t blen; int ret; + struct cli_protocol *pctx; struct nss_ctx *nctx; enum idmap_error_code err; uint8_t *bin_sid = NULL; @@ -5401,6 +5452,8 @@ static int nss_cmd_getbysid(enum sss_cli_command cmd, struct cli_ctx *cctx) return EINVAL; } + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + cmdctx = talloc_zero(cctx, struct nss_cmd_ctx); if (!cmdctx) { return ENOMEM; @@ -5416,7 +5469,7 @@ static int nss_cmd_getbysid(enum sss_cli_command cmd, struct cli_ctx *cctx) dctx->cmdctx = cmdctx; /* get SID to query */ - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); /* if not terminated fail */ if (body[blen -1] != '\0') { @@ -5507,8 +5560,10 @@ static int nss_cmd_getbycert(enum sss_cli_command cmd, struct cli_ctx *cctx) char *pem_cert = NULL; size_t pem_size; struct nss_ctx *nctx; + struct cli_protocol *pctx; nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); if (cmd != SSS_NSS_GETNAMEBYCERT) { DEBUG(SSSDBG_CRIT_FAILURE, "Invalid command type [%d][%s].\n", @@ -5517,7 +5572,7 @@ static int nss_cmd_getbycert(enum sss_cli_command cmd, struct cli_ctx *cctx) } /* get certificate to query */ - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); /* if not terminated fail */ if (body[blen - 1] != '\0') { @@ -5551,9 +5606,11 @@ static void users_find_by_cert_done(struct tevent_req *req) struct cli_ctx *cctx; struct sss_domain_info *domain; struct ldb_result *result; + struct cli_protocol *pctx; errno_t ret; cctx = tevent_req_callback_data(req, struct cli_ctx); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); ret = cache_req_user_by_cert_recv(cctx, req, &result, &domain, NULL); talloc_zfree(req); @@ -5572,16 +5629,16 @@ static void users_find_by_cert_done(struct tevent_req *req) goto done; } - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { DEBUG(SSSDBG_OP_FAILURE, "sss_packet_new failed.\n"); ret = EFAULT; goto done; } - ret = fill_name(cctx->creq->out, domain, SSS_ID_TYPE_UID, true, + ret = fill_name(pctx->creq->out, domain, SSS_ID_TYPE_UID, true, result->msgs[0]); if (ret != EOK) { DEBUG(SSSDBG_OP_FAILURE, "fill_name failed.\n"); @@ -5592,7 +5649,7 @@ static void users_find_by_cert_done(struct tevent_req *req) done: if (ret == EOK) { - sss_packet_set_error(cctx->creq->out, EOK); + sss_packet_set_error(pctx->creq->out, EOK); sss_cmd_done(cctx, NULL); } else if (ret == ENOENT) { sss_cmd_send_empty(cctx, NULL); @@ -5676,3 +5733,18 @@ static struct sss_cmd_table nss_cmds[] = { struct sss_cmd_table *get_nss_cmds(void) { return nss_cmds; } + +int nss_connection_setup(struct cli_ctx *cctx) +{ + int ret; + + ret = sss_connection_setup(cctx); + if (ret != EOK) return ret; + + cctx->state_ctx = talloc_zero(cctx, struct nss_state_ctx); + if (!cctx->state_ctx) { + return ENOMEM; + } + + return EOK; +} diff --git a/src/responder/nss/nsssrv_netgroup.c b/src/responder/nss/nsssrv_netgroup.c index a3c74a3fd..e42976b24 100644 --- a/src/responder/nss/nsssrv_netgroup.c +++ b/src/responder/nss/nsssrv_netgroup.c @@ -94,6 +94,8 @@ static void nss_cmd_setnetgrent_done(struct tevent_req *req); int nss_cmd_setnetgrent(struct cli_ctx *client) { struct nss_cmd_ctx *cmdctx; + struct cli_protocol *pctx; + struct nss_state_ctx *state_ctx; struct tevent_req *req; const char *rawname; uint8_t *body; @@ -101,7 +103,8 @@ int nss_cmd_setnetgrent(struct cli_ctx *client) errno_t ret = EOK; /* Reset the result cursor to zero */ - client->netgrent_cur = 0; + state_ctx = talloc_get_type(client->state_ctx, struct nss_state_ctx); + state_ctx->netgrent_cur = 0; cmdctx = talloc_zero(client, struct nss_cmd_ctx); if (!cmdctx) { @@ -109,8 +112,10 @@ int nss_cmd_setnetgrent(struct cli_ctx *client) } cmdctx->cctx = client; + pctx = talloc_get_type(client->protocol_ctx, struct cli_protocol); + /* get netgroup name to query */ - sss_packet_get_body(client->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); /* if not terminated fail */ if (body[blen -1] != '\0') { @@ -184,10 +189,12 @@ static struct tevent_req *setnetgrent_send(TALLOC_CTX *mem_ctx, struct tevent_req *req; struct setnetgrent_ctx *state; struct nss_dom_ctx *dctx; - struct cli_ctx *client = cmdctx->cctx; - struct nss_ctx *nctx = - talloc_get_type(client->rctx->pvt_ctx, struct nss_ctx); + struct nss_ctx *nctx; + struct nss_state_ctx *state_ctx; + + nctx = talloc_get_type(client->rctx->pvt_ctx, struct nss_ctx); + state_ctx = talloc_get_type(client->state_ctx, struct nss_state_ctx); req = tevent_req_create(mem_ctx, &state, struct setnetgrent_ctx); if (!req) { @@ -227,8 +234,8 @@ static struct tevent_req *setnetgrent_send(TALLOC_CTX *mem_ctx, } /* Save the netgroup name for getnetgrent */ - client->netgr_name = talloc_strdup(client, state->netgr_shortname); - if (!client->netgr_name) { + state_ctx->netgr_name = talloc_strdup(client, state->netgr_shortname); + if (!state_ctx->netgr_name) { ret = ENOMEM; goto error; } @@ -238,8 +245,8 @@ static struct tevent_req *setnetgrent_send(TALLOC_CTX *mem_ctx, cmdctx->check_next = true; /* Save the netgroup name for getnetgrent */ - client->netgr_name = talloc_strdup(client, rawname); - if (!client->netgr_name) { + state_ctx->netgr_name = talloc_strdup(client, rawname); + if (!state_ctx->netgr_name) { ret = ENOMEM; goto error; } @@ -272,6 +279,7 @@ static errno_t setnetgrent_retry(struct tevent_req *req) struct setnetgrent_ctx *state; struct cli_ctx *client; struct nss_ctx *nctx; + struct nss_state_ctx *state_ctx; struct nss_cmd_ctx *cmdctx; struct nss_dom_ctx *dctx; @@ -280,13 +288,14 @@ static errno_t setnetgrent_retry(struct tevent_req *req) cmdctx = state->cmdctx; client = cmdctx->cctx; nctx = talloc_get_type(client->rctx->pvt_ctx, struct nss_ctx); + state_ctx = talloc_get_type(client->state_ctx, struct nss_state_ctx); dctx->check_provider = NEED_CHECK_PROVIDER(dctx->domain->provider); /* Is the result context already available? * Check for existing lookups for this netgroup */ - ret = get_netgroup_entry(nctx, client->netgr_name, &state->netgr); + ret = get_netgroup_entry(nctx, state_ctx->netgr_name, &state->netgr); if (ret == EOK) { /* Another process already requested this netgroup * Check whether it's ready for processing. @@ -328,7 +337,7 @@ static errno_t setnetgrent_retry(struct tevent_req *req) * so we can remove it in the destructor */ state->netgr->name = talloc_strdup(state->netgr, - client->netgr_name); + state_ctx->netgr_name); if (!state->netgr->name) { talloc_free(state->netgr); ret = ENOMEM; @@ -718,7 +727,7 @@ static void nss_cmd_setnetgrent_done(struct tevent_req *req) struct sss_packet *packet; uint8_t *body; size_t blen; - + struct cli_protocol *pctx; struct nss_cmd_ctx *cmdctx = tevent_req_callback_data(req, struct nss_cmd_ctx); @@ -730,16 +739,18 @@ static void nss_cmd_setnetgrent_done(struct tevent_req *req) return; } + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* Either we succeeded or no domains were eligible */ - ret = sss_packet_new(cmdctx->cctx->creq, 0, - sss_packet_get_cmd(cmdctx->cctx->creq->in), - &cmdctx->cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret == EOK) { if (reqret == ENOENT) { /* Notify the caller that this entry wasn't found */ - sss_cmd_empty_packet(cmdctx->cctx->creq->out); + sss_cmd_empty_packet(pctx->creq->out); } else { - packet = cmdctx->cctx->creq->out; + packet = pctx->creq->out; ret = sss_packet_grow(packet, 2*sizeof(uint32_t)); if (ret != EOK) { DEBUG(SSSDBG_CRIT_FAILURE, "Couldn't grow the packet\n"); @@ -769,6 +780,7 @@ int nss_cmd_getnetgrent(struct cli_ctx *client) { errno_t ret; struct nss_ctx *nctx; + struct nss_state_ctx *state_ctx; struct nss_cmd_ctx *cmdctx; struct getent_ctx *netgr; struct tevent_req *req; @@ -782,8 +794,9 @@ int nss_cmd_getnetgrent(struct cli_ctx *client) cmdctx->cctx = client; nctx = talloc_get_type(client->rctx->pvt_ctx, struct nss_ctx); + state_ctx = talloc_get_type(client->state_ctx, struct nss_state_ctx); - if (!client->netgr_name) { + if (!state_ctx->netgr_name) { /* Tried to run getnetgrent without a preceding * setnetgrent. There is no way to determine which * netgroup is being requested. @@ -792,13 +805,13 @@ int nss_cmd_getnetgrent(struct cli_ctx *client) } /* Look up the results from the hash */ - ret = get_netgroup_entry(nctx, client->netgr_name, &netgr); + ret = get_netgroup_entry(nctx, state_ctx->netgr_name, &netgr); if (ret == ENOENT) { /* We need to invoke an implicit setnetgrent() to * wait for the result object to become available. */ - req = setnetgrent_send(cmdctx, client->netgr_name, cmdctx); + req = setnetgrent_send(cmdctx, state_ctx->netgr_name, cmdctx); if (!req) { return nss_cmd_done(cmdctx, EIO); } @@ -817,7 +830,7 @@ int nss_cmd_getnetgrent(struct cli_ctx *client) /* We need to invoke an implicit setnetgrent() to * wait for the result object to become available. */ - req = setnetgrent_send(cmdctx, client->netgr_name, cmdctx); + req = setnetgrent_send(cmdctx, state_ctx->netgr_name, cmdctx); if (!req) { return nss_cmd_done(cmdctx, EIO); } @@ -826,12 +839,12 @@ int nss_cmd_getnetgrent(struct cli_ctx *client) return EOK; } else if (!netgr->found) { DEBUG(SSSDBG_TRACE_FUNC, - "Results for [%s] not found.\n", client->netgr_name); + "Results for [%s] not found.\n", state_ctx->netgr_name); return ENOENT; } DEBUG(SSSDBG_TRACE_FUNC, - "Returning results for [%s]\n", client->netgr_name); + "Returning results for [%s]\n", state_ctx->netgr_name); /* Read the result strings */ ret = nss_cmd_getnetgrent_process(cmdctx, netgr); @@ -847,8 +860,11 @@ static void setnetgrent_implicit_done(struct tevent_req *req) struct getent_ctx *netgr; struct nss_cmd_ctx *cmdctx = tevent_req_callback_data(req, struct nss_cmd_ctx); - struct nss_ctx *nctx = - talloc_get_type(cmdctx->cctx->rctx->pvt_ctx, struct nss_ctx); + struct nss_ctx *nctx; + struct nss_state_ctx *state_ctx; + + nctx = talloc_get_type(cmdctx->cctx->rctx->pvt_ctx, struct nss_ctx); + state_ctx = talloc_get_type(cmdctx->cctx->state_ctx, struct nss_state_ctx); ret = setnetgrent_recv(req); talloc_zfree(req); @@ -871,7 +887,7 @@ static void setnetgrent_implicit_done(struct tevent_req *req) } /* Look up the results from the hash */ - ret = get_netgroup_entry(nctx, cmdctx->cctx->netgr_name, &netgr); + ret = get_netgroup_entry(nctx, state_ctx->netgr_name, &netgr); if (ret == ENOENT) { /* Critical error. This should never happen */ DEBUG(SSSDBG_FATAL_FAILURE, @@ -908,23 +924,25 @@ static errno_t nss_cmd_retnetgrent(struct cli_ctx *client, static errno_t nss_cmd_getnetgrent_process(struct nss_cmd_ctx *cmdctx, struct getent_ctx *netgr) { - struct cli_ctx *client = cmdctx->cctx; + struct cli_protocol *pctx; uint8_t *body; size_t blen; uint32_t num; errno_t ret; + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* get max num of entries to return in one call */ - sss_packet_get_body(client->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); if (blen != sizeof(uint32_t)) { return EINVAL; } SAFEALIGN_COPY_UINT32(&num, body, NULL); /* create response packet */ - ret = sss_packet_new(client->creq, 0, - sss_packet_get_cmd(client->creq->in), - &client->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } @@ -932,18 +950,18 @@ static errno_t nss_cmd_getnetgrent_process(struct nss_cmd_ctx *cmdctx, if (!netgr->entries || netgr->entries[0] == NULL) { /* No entries */ DEBUG(SSSDBG_FUNC_DATA, "No entries found\n"); - ret = sss_cmd_empty_packet(client->creq->out); + ret = sss_cmd_empty_packet(pctx->creq->out); if (ret != EOK) { return nss_cmd_done(cmdctx, ret); } goto done; } - ret = nss_cmd_retnetgrent(client, netgr->entries, num); + ret = nss_cmd_retnetgrent(cmdctx->cctx, netgr->entries, num); done: - sss_packet_set_error(client->creq->out, ret); - sss_cmd_done(client, cmdctx); + sss_packet_set_error(pctx->creq->out, ret); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } @@ -960,32 +978,37 @@ static errno_t nss_cmd_retnetgrent(struct cli_ctx *client, uint8_t *body; size_t blen, rp; errno_t ret; - struct sss_packet *packet = client->creq->out; - int num, start; + struct cli_protocol *pctx; + struct nss_state_ctx *state_ctx; + struct sss_packet *packet; + int num, start, cur; + + state_ctx = talloc_get_type(client->state_ctx, struct nss_state_ctx); + pctx = talloc_get_type(client->protocol_ctx, struct cli_protocol); + packet = pctx->creq->out; /* first 2 fields (len and reserved), filled up later */ rp = 2*sizeof(uint32_t); ret = sss_packet_grow(packet, rp); if (ret != EOK) return ret; - start = client->netgrent_cur; + start = cur = state_ctx->netgrent_cur; num = 0; - while (entries[client->netgrent_cur] && - (client->netgrent_cur - start) < count) { - if (entries[client->netgrent_cur]->type == SYSDB_NETGROUP_TRIPLE_VAL) { + while (entries[cur] && (cur - start) < count) { + if (entries[cur]->type == SYSDB_NETGROUP_TRIPLE_VAL) { hostlen = 1; - if (entries[client->netgrent_cur]->value.triple.hostname) { - hostlen += strlen(entries[client->netgrent_cur]->value.triple.hostname); + if (entries[cur]->value.triple.hostname) { + hostlen += strlen(entries[cur]->value.triple.hostname); } userlen = 1; - if (entries[client->netgrent_cur]->value.triple.username) { - userlen += strlen(entries[client->netgrent_cur]->value.triple.username); + if (entries[cur]->value.triple.username) { + userlen += strlen(entries[cur]->value.triple.username); } domainlen = 1; - if (entries[client->netgrent_cur]->value.triple.domainname) { - domainlen += strlen(entries[client->netgrent_cur]->value.triple.domainname); + if (entries[cur]->value.triple.domainname) { + domainlen += strlen(entries[cur]->value.triple.domainname); } len = sizeof(uint32_t) + hostlen + userlen + domainlen; @@ -1001,7 +1024,7 @@ static errno_t nss_cmd_retnetgrent(struct cli_ctx *client, body[rp] = '\0'; } else { memcpy(&body[rp], - entries[client->netgrent_cur]->value.triple.hostname, + entries[cur]->value.triple.hostname, hostlen); } rp += hostlen; @@ -1010,7 +1033,7 @@ static errno_t nss_cmd_retnetgrent(struct cli_ctx *client, body[rp] = '\0'; } else { memcpy(&body[rp], - entries[client->netgrent_cur]->value.triple.username, + entries[cur]->value.triple.username, userlen); } rp += userlen; @@ -1019,19 +1042,19 @@ static errno_t nss_cmd_retnetgrent(struct cli_ctx *client, body[rp] = '\0'; } else { memcpy(&body[rp], - entries[client->netgrent_cur]->value.triple.domainname, + entries[cur]->value.triple.domainname, domainlen); } rp += domainlen; - } else if (entries[client->netgrent_cur]->type == SYSDB_NETGROUP_GROUP_VAL) { - if (entries[client->netgrent_cur]->value.groupname == NULL || - entries[client->netgrent_cur]->value.groupname[0] == '\0') { + } else if (entries[cur]->type == SYSDB_NETGROUP_GROUP_VAL) { + if (entries[cur]->value.groupname == NULL || + entries[cur]->value.groupname[0] == '\0') { DEBUG(SSSDBG_CRIT_FAILURE, "Empty netgroup member. Please check your cache.\n"); continue; } - grouplen = 1 + strlen(entries[client->netgrent_cur]->value.groupname); + grouplen = 1 + strlen(entries[cur]->value.groupname); len = sizeof(uint32_t) + grouplen; @@ -1045,7 +1068,7 @@ static errno_t nss_cmd_retnetgrent(struct cli_ctx *client, SAFEALIGN_SET_UINT32(&body[rp], SSS_NETGR_REP_GROUP, &rp); memcpy(&body[rp], - entries[client->netgrent_cur]->value.groupname, + entries[cur]->value.groupname, grouplen); rp += grouplen; } else { @@ -1056,7 +1079,8 @@ static errno_t nss_cmd_retnetgrent(struct cli_ctx *client, } num++; - client->netgrent_cur++; + cur++; + state_ctx->netgrent_cur = cur; } sss_packet_get_body(packet, &body, &blen); @@ -1072,20 +1096,25 @@ static errno_t nss_cmd_retnetgrent(struct cli_ctx *client, int nss_cmd_endnetgrent(struct cli_ctx *client) { + struct cli_protocol *pctx; + struct nss_state_ctx *state_ctx; errno_t ret; + pctx = talloc_get_type(client->protocol_ctx, struct cli_protocol); + state_ctx = talloc_get_type(client->state_ctx, struct nss_state_ctx); + /* create response packet */ - ret = sss_packet_new(client->creq, 0, - sss_packet_get_cmd(client->creq->in), - &client->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } /* Reset the indices so that subsequent requests start at zero */ - client->netgrent_cur = 0; - talloc_zfree(client->netgr_name); + state_ctx->netgrent_cur = 0; + talloc_zfree(state_ctx->netgr_name); sss_cmd_done(client, NULL); return EOK; diff --git a/src/responder/nss/nsssrv_private.h b/src/responder/nss/nsssrv_private.h index 72f7b7560..79c7b7265 100644 --- a/src/responder/nss/nsssrv_private.h +++ b/src/responder/nss/nsssrv_private.h @@ -27,6 +27,19 @@ #include <dhash.h> +struct nss_state_ent { + int dom_idx; + int cur; +}; + +struct nss_state_ctx { + struct nss_state_ent pwent; + struct nss_state_ent grent; + struct nss_state_ent svcent; + char *netgr_name; + int netgrent_cur; +}; + struct nss_cmd_ctx { struct cli_ctx *cctx; enum sss_cli_command cmd; @@ -136,4 +149,6 @@ void nss_update_initgr_memcache(struct nss_ctx *nctx, const char *name, const char *domain, int gnum, uint32_t *groups); +int nss_connection_setup(struct cli_ctx *cctx); + #endif /* NSSSRV_PRIVATE_H_ */ diff --git a/src/responder/nss/nsssrv_services.c b/src/responder/nss/nsssrv_services.c index 05f9d52fa..e1afac54e 100644 --- a/src/responder/nss/nsssrv_services.c +++ b/src/responder/nss/nsssrv_services.c @@ -790,6 +790,7 @@ nss_cmd_getserv_done(struct tevent_req *req); int nss_cmd_getservbyname(struct cli_ctx *cctx) { errno_t ret; + struct cli_protocol *pctx; struct nss_cmd_ctx *cmdctx; struct nss_dom_ctx *dctx; char *domname; @@ -811,8 +812,10 @@ int nss_cmd_getservbyname(struct cli_ctx *cctx) } dctx->cmdctx = cmdctx; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + /* get service name and protocol */ - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); /* if not terminated fail */ if (body[blen -1] != '\0') { ret = EINVAL; @@ -986,7 +989,7 @@ nss_cmd_getserv_done(struct tevent_req *req) { errno_t ret, reqret; unsigned int i; - + struct cli_protocol *pctx; struct nss_dom_ctx *dctx = tevent_req_callback_data(req, struct nss_dom_ctx); struct nss_cmd_ctx *cmdctx = dctx->cmdctx; @@ -1000,17 +1003,19 @@ nss_cmd_getserv_done(struct tevent_req *req) return; } + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* Either we succeeded or no domains were eligible */ - ret = sss_packet_new(cmdctx->cctx->creq, 0, - sss_packet_get_cmd(cmdctx->cctx->creq->in), - &cmdctx->cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret == EOK) { if (reqret == ENOENT) { /* Notify the caller that this entry wasn't found */ - ret = sss_cmd_empty_packet(cmdctx->cctx->creq->out); + ret = sss_cmd_empty_packet(pctx->creq->out); } else { i = dctx->res->count; - ret = fill_service(cmdctx->cctx->creq->out, + ret = fill_service(pctx->creq->out, dctx->domain, dctx->protocol, dctx->res->msgs, @@ -1105,6 +1110,7 @@ done: int nss_cmd_getservbyport(struct cli_ctx *cctx) { errno_t ret; + struct cli_protocol *pctx; struct nss_cmd_ctx *cmdctx; struct nss_dom_ctx *dctx; uint16_t port; @@ -1125,8 +1131,10 @@ int nss_cmd_getservbyport(struct cli_ctx *cctx) } dctx->cmdctx = cmdctx; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + /* get service port and protocol */ - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); /* if not terminated fail */ if (body[blen -1] != '\0') { ret = EINVAL; @@ -1194,14 +1202,17 @@ setservent_send(TALLOC_CTX *mem_ctx, struct cli_ctx *cctx) struct setservent_ctx *state; struct sss_domain_info *dom; struct setent_step_ctx *step_ctx; - struct nss_ctx *nctx = - talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); + struct nss_ctx *nctx; + struct nss_state_ctx *state_ctx; DEBUG(SSSDBG_TRACE_FUNC, "Received setservent request\n"); + nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); + /* Reset the read pointers */ - cctx->svc_dom_idx = 0; - cctx->svcent_cur = 0; + state_ctx->svcent.dom_idx = 0; + state_ctx->svcent.cur = 0; req = tevent_req_create(mem_ctx, &state, struct setservent_ctx); if (!req) return NULL; @@ -1611,7 +1622,9 @@ nss_cmd_setservent_done(struct tevent_req *req) errno_t ret; struct nss_cmd_ctx *cmdctx = tevent_req_callback_data(req, struct nss_cmd_ctx); + struct cli_protocol *pctx; + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); ret = setservent_recv(req); talloc_zfree(req); if (ret == EOK || ret == ENOENT) { @@ -1619,9 +1632,9 @@ nss_cmd_setservent_done(struct tevent_req *req) * were eligible. * Return an acknowledgment */ - ret = sss_packet_new(cmdctx->cctx->creq, 0, - sss_packet_get_cmd(cmdctx->cctx->creq->in), - &cmdctx->cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret == EOK) { sss_cmd_done(cmdctx->cctx, cmdctx); return; @@ -1648,6 +1661,7 @@ int nss_cmd_getservent(struct cli_ctx *cctx) struct nss_ctx *nctx; struct nss_cmd_ctx *cmdctx; struct tevent_req *req; + struct nss_state_ctx *state_ctx; DEBUG(SSSDBG_TRACE_FUNC, "Requesting info for all services\n"); @@ -1663,8 +1677,9 @@ int nss_cmd_getservent(struct cli_ctx *cctx) * expired and has to be recreated, we want to resume from the same * location. */ - cmdctx->saved_dom_idx = cctx->svc_dom_idx; - cmdctx->saved_cur = cctx->svcent_cur; + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); + cmdctx->saved_dom_idx = state_ctx->svcent.dom_idx; + cmdctx->saved_cur = state_ctx->svcent.cur; nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); if(!nctx->svcctx || !nctx->svcctx->ready) { @@ -1690,6 +1705,7 @@ nss_cmd_implicit_setservent_done(struct tevent_req *req) errno_t ret; struct nss_cmd_ctx *cmdctx = tevent_req_callback_data(req, struct nss_cmd_ctx); + struct nss_state_ctx *state_ctx; ret = setservent_recv(req); talloc_zfree(req); @@ -1706,8 +1722,9 @@ nss_cmd_implicit_setservent_done(struct tevent_req *req) } /* Restore the saved index and cursor locations */ - cmdctx->cctx->svc_dom_idx = cmdctx->saved_dom_idx; - cmdctx->cctx->svcent_cur = cmdctx->saved_cur; + state_ctx = talloc_get_type(cmdctx->cctx->state_ctx, struct nss_state_ctx); + state_ctx->svcent.dom_idx = cmdctx->saved_dom_idx; + state_ctx->svcent.cur = cmdctx->saved_cur; ret = nss_cmd_getservent_immediate(cmdctx); if (ret != EOK) { @@ -1721,31 +1738,33 @@ nss_cmd_implicit_setservent_done(struct tevent_req *req) static errno_t nss_cmd_getservent_immediate(struct nss_cmd_ctx *cmdctx) { - struct cli_ctx *cctx = cmdctx->cctx; + struct cli_protocol *pctx; uint8_t *body; size_t blen; uint32_t num; int ret; + pctx = talloc_get_type(cmdctx->cctx->protocol_ctx, struct cli_protocol); + /* get max num of entries to return in one call */ - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); if (blen != sizeof(uint32_t)) { return EINVAL; } SAFEALIGN_COPY_UINT32(&num, body, NULL); /* create response packet */ - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } - ret = retservent(cctx, num); + ret = retservent(cmdctx->cctx, num); - sss_packet_set_error(cctx->creq->out, ret); - sss_cmd_done(cctx, cmdctx); + sss_packet_set_error(pctx->creq->out, ret); + sss_cmd_done(cmdctx->cctx, cmdctx); return EOK; } @@ -1753,6 +1772,8 @@ nss_cmd_getservent_immediate(struct nss_cmd_ctx *cmdctx) static errno_t retservent(struct cli_ctx *cctx, int num) { + struct cli_protocol *pctx; + struct nss_state_ctx *state_ctx; struct nss_ctx *nctx; struct getent_ctx *svcctx; struct ldb_message **msgs = NULL; @@ -1760,59 +1781,65 @@ retservent(struct cli_ctx *cctx, int num) unsigned int n = 0; int ret = ENOENT; + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); if (!nctx->svcctx) goto none; svcctx = nctx->svcctx; while (ret == ENOENT) { - if (cctx->svc_dom_idx >= svcctx->num) break; + if (state_ctx->svcent.dom_idx >= svcctx->num) break; - pdom = &svcctx->doms[cctx->svc_dom_idx]; + pdom = &svcctx->doms[state_ctx->svcent.dom_idx]; - n = pdom->res->count - cctx->svcent_cur; - if (n <= 0 && (cctx->svc_dom_idx+1 < svcctx->num)) { - cctx->svc_dom_idx++; - pdom = &svcctx->doms[cctx->svc_dom_idx]; + n = pdom->res->count - state_ctx->svcent.cur; + if (n <= 0 && (state_ctx->svcent.dom_idx+1 < svcctx->num)) { + state_ctx->svcent.dom_idx++; + pdom = &svcctx->doms[state_ctx->svcent.dom_idx]; n = pdom->res->count; - cctx->svcent_cur = 0; + state_ctx->svcent.cur = 0; } if (!n) break; if (n > num) n = num; - msgs = &(pdom->res->msgs[cctx->svcent_cur]); + msgs = &(pdom->res->msgs[state_ctx->svcent.cur]); - ret = fill_service(cctx->creq->out, + ret = fill_service(pctx->creq->out, pdom->domain, NULL, msgs, &n); - cctx->svcent_cur += n; + state_ctx->svcent.cur += n; } none: if (ret == ENOENT) { - ret = sss_cmd_empty_packet(cctx->creq->out); + ret = sss_cmd_empty_packet(pctx->creq->out); } return ret; } int nss_cmd_endservent(struct cli_ctx *cctx) { + struct cli_protocol *pctx; + struct nss_state_ctx *state_ctx; struct nss_ctx *nctx; int ret; DEBUG(SSSDBG_TRACE_FUNC, "Terminating request info for all accounts\n"); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + state_ctx = talloc_get_type(cctx->state_ctx, struct nss_state_ctx); nctx = talloc_get_type(cctx->rctx->pvt_ctx, struct nss_ctx); /* create response packet */ - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; @@ -1820,8 +1847,8 @@ int nss_cmd_endservent(struct cli_ctx *cctx) if (nctx->svcctx == NULL) goto done; /* Reset the indices so that subsequent requests start at zero */ - cctx->svc_dom_idx = 0; - cctx->svcent_cur = 0; + state_ctx->svcent.dom_idx = 0; + state_ctx->svcent.cur = 0; done: sss_cmd_done(cctx, NULL); diff --git a/src/responder/pac/pacsrv.c b/src/responder/pac/pacsrv.c index 8e919780a..15d1986f8 100644 --- a/src/responder/pac/pacsrv.c +++ b/src/responder/pac/pacsrv.c @@ -123,6 +123,7 @@ int pac_process_init(TALLOC_CTX *mem_ctx, PAC_SBUS_SERVICE_VERSION, &monitor_pac_methods, "PAC", &pac_dp_methods.vtable, + sss_connection_setup, &rctx); if (ret != EOK) { DEBUG(SSSDBG_FATAL_FAILURE, "sss_process_init() failed\n"); diff --git a/src/responder/pac/pacsrv_cmd.c b/src/responder/pac/pacsrv_cmd.c index 0e2b25c33..c9514cf0a 100644 --- a/src/responder/pac/pacsrv_cmd.c +++ b/src/responder/pac/pacsrv_cmd.c @@ -29,6 +29,7 @@ static errno_t pac_cmd_done(struct cli_ctx *cctx, int cmd_ret) { + struct cli_protocol *pctx; int ret; if (cmd_ret == EAGAIN) { @@ -36,15 +37,17 @@ static errno_t pac_cmd_done(struct cli_ctx *cctx, int cmd_ret) return EOK; } - ret = sss_packet_new(cctx->creq, 0, sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + + ret = sss_packet_new(pctx->creq, 0, sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { DEBUG(SSSDBG_OP_FAILURE, "sss_packet_new failed [%d][%s].\n", ret, strerror(ret)); return ret; } - sss_packet_set_error(cctx->creq->out, cmd_ret); + sss_packet_set_error(pctx->creq->out, cmd_ret); sss_cmd_done(cctx, NULL); @@ -78,8 +81,11 @@ static errno_t pac_add_pac_user(struct cli_ctx *cctx) struct pac_req_ctx *pr_ctx; struct tevent_req *req; enum idmap_error_code err; + struct cli_protocol *pctx; + + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); - sss_packet_get_body(cctx->creq->in, &body, &blen); + sss_packet_get_body(pctx->creq->in, &body, &blen); pr_ctx = talloc_zero(cctx, struct pac_req_ctx); if (pr_ctx == NULL) { diff --git a/src/responder/pam/pamsrv.c b/src/responder/pam/pamsrv.c index 7e037b403..efd1e5c75 100644 --- a/src/responder/pam/pamsrv.c +++ b/src/responder/pam/pamsrv.c @@ -202,6 +202,7 @@ static int pam_process_init(TALLOC_CTX *mem_ctx, SSS_PAM_SBUS_SERVICE_VERSION, &monitor_pam_methods, "PAM", &pam_dp_methods.vtable, + sss_connection_setup, &rctx); if (ret != EOK) { DEBUG(SSSDBG_FATAL_FAILURE, "sss_process_init() failed\n"); diff --git a/src/responder/pam/pamsrv_cmd.c b/src/responder/pam/pamsrv_cmd.c index a25d2ef64..22a1872a2 100644 --- a/src/responder/pam/pamsrv_cmd.c +++ b/src/responder/pam/pamsrv_cmd.c @@ -585,6 +585,7 @@ static void pam_handle_cached_login(struct pam_auth_req *preq, int ret, static void pam_reply(struct pam_auth_req *preq) { struct cli_ctx *cctx; + struct cli_protocol *prctx; uint8_t *body; size_t blen; int ret; @@ -606,6 +607,7 @@ static void pam_reply(struct pam_auth_req *preq) pd = preq->pd; cctx = preq->cctx; pctx = talloc_get_type(preq->cctx->rctx->pvt_ctx, struct pam_ctx); + prctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); ret = confdb_get_int(pctx->rctx->cdb, CONFDB_PAM_CONF_ENTRY, CONFDB_PAM_VERBOSITY, DEFAULT_PAM_VERBOSITY, @@ -738,8 +740,8 @@ static void pam_reply(struct pam_auth_req *preq) return; } - ret = sss_packet_new(cctx->creq, 0, sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ret = sss_packet_new(prctx->creq, 0, sss_packet_get_cmd(prctx->creq->in), + &prctx->creq->out); if (ret != EOK) { goto done; } @@ -805,7 +807,7 @@ static void pam_reply(struct pam_auth_req *preq) resp = resp->next; } - ret = sss_packet_grow(cctx->creq->out, sizeof(int32_t) + + ret = sss_packet_grow(prctx->creq->out, sizeof(int32_t) + sizeof(int32_t) + resp_c * 2* sizeof(int32_t) + resp_size); @@ -813,7 +815,7 @@ static void pam_reply(struct pam_auth_req *preq) goto done; } - sss_packet_get_body(cctx->creq->out, &body, &blen); + sss_packet_get_body(prctx->creq->out, &body, &blen); DEBUG(SSSDBG_FUNC_DATA, "blen: %zu\n", blen); p = 0; @@ -928,12 +930,15 @@ static int pam_check_user_done(struct pam_auth_req *preq, int ret); static errno_t pam_forwarder_parse_data(struct cli_ctx *cctx, struct pam_data *pd) { + struct cli_protocol *prctx; uint8_t *body; size_t blen; errno_t ret; uint32_t terminator; - sss_packet_get_body(cctx->creq->in, &body, &blen); + prctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + + sss_packet_get_body(prctx->creq->in, &body, &blen); if (blen >= sizeof(uint32_t)) { SAFEALIGN_COPY_UINT32(&terminator, body + blen - sizeof(uint32_t), @@ -945,7 +950,7 @@ static errno_t pam_forwarder_parse_data(struct cli_ctx *cctx, struct pam_data *p } } - switch (cctx->cli_protocol_version->version) { + switch (prctx->cli_protocol_version->version) { case 1: ret = pam_parse_in_data(pd, body, blen); break; @@ -957,7 +962,7 @@ static errno_t pam_forwarder_parse_data(struct cli_ctx *cctx, struct pam_data *p break; default: DEBUG(SSSDBG_CRIT_FAILURE, "Illegal protocol version [%d].\n", - cctx->cli_protocol_version->version); + prctx->cli_protocol_version->version); ret = EINVAL; } if (ret != EOK) { diff --git a/src/responder/ssh/sshsrv.c b/src/responder/ssh/sshsrv.c index 2be7d4bf2..f763e3b00 100644 --- a/src/responder/ssh/sshsrv.c +++ b/src/responder/ssh/sshsrv.c @@ -97,6 +97,7 @@ int ssh_process_init(TALLOC_CTX *mem_ctx, &monitor_ssh_methods, "SSH", &ssh_dp_methods.vtable, + sss_connection_setup, &rctx); if (ret != EOK) { DEBUG(SSSDBG_FATAL_FAILURE, "sss_process_init() failed\n"); diff --git a/src/responder/ssh/sshsrv_cmd.c b/src/responder/ssh/sshsrv_cmd.c index ba3b694d9..1baba8b03 100644 --- a/src/responder/ssh/sshsrv_cmd.c +++ b/src/responder/ssh/sshsrv_cmd.c @@ -662,9 +662,8 @@ done: static errno_t ssh_cmd_parse_request(struct ssh_cmd_ctx *cmd_ctx) { - struct cli_ctx *cctx = cmd_ctx->cctx; - struct ssh_ctx *ssh_ctx = talloc_get_type(cctx->rctx->pvt_ctx, - struct ssh_ctx); + struct cli_protocol *pctx; + struct ssh_ctx *ssh_ctx; errno_t ret; uint8_t *body; size_t body_len; @@ -677,7 +676,10 @@ ssh_cmd_parse_request(struct ssh_cmd_ctx *cmd_ctx) uint32_t domain_len; char *domain = NULL; - sss_packet_get_body(cctx->creq->in, &body, &body_len); + ssh_ctx = talloc_get_type(cmd_ctx->cctx->rctx->pvt_ctx, struct ssh_ctx); + pctx = talloc_get_type(cmd_ctx->cctx->protocol_ctx, struct cli_protocol); + + sss_packet_get_body(pctx->creq->in, &body, &body_len); SAFEALIGN_COPY_UINT32_CHECK(&flags, body+c, body_len, &c); if (flags & ~(uint32_t)SSS_SSH_REQ_MASK) { @@ -752,7 +754,8 @@ ssh_cmd_parse_request(struct ssh_cmd_ctx *cmd_ctx) DEBUG(SSSDBG_TRACE_FUNC, "Parsing name [%s][%s]\n", name, domain ? domain : "<ALL>"); - ret = sss_parse_name_for_domains(cmd_ctx, cctx->rctx->domains, + ret = sss_parse_name_for_domains(cmd_ctx, + cmd_ctx->cctx->rctx->domains, domain, name, &cmd_ctx->domname, &cmd_ctx->name); @@ -882,7 +885,7 @@ static errno_t decode_and_add_base64_data(struct ssh_cmd_ctx *cmd_ctx, const char *fqname, size_t *c) { - struct cli_ctx *cctx = cmd_ctx->cctx; + struct cli_protocol *pctx; uint8_t *key; size_t key_len; uint8_t *body; @@ -902,6 +905,8 @@ static errno_t decode_and_add_base64_data(struct ssh_cmd_ctx *cmd_ctx, return ENOMEM; } + pctx = talloc_get_type(cmd_ctx->cctx->protocol_ctx, struct cli_protocol); + for (d = 0; d < el->num_values; d++) { if (skip_base64_decode) { key = el->values[d].data; @@ -916,13 +921,13 @@ static errno_t decode_and_add_base64_data(struct ssh_cmd_ctx *cmd_ctx, } } - ret = sss_packet_grow(cctx->creq->out, + ret = sss_packet_grow(pctx->creq->out, 3*sizeof(uint32_t) + key_len + fqname_len); if (ret != EOK) { DEBUG(SSSDBG_OP_FAILURE, "sss_packet_grow failed.\n"); goto done; } - sss_packet_get_body(cctx->creq->out, &body, &body_len); + sss_packet_get_body(pctx->creq->out, &body, &body_len); SAFEALIGN_SET_UINT32(body+(*c), 0, c); SAFEALIGN_SET_UINT32(body+(*c), fqname_len, c); @@ -943,7 +948,6 @@ done: static errno_t ssh_cmd_build_reply(struct ssh_cmd_ctx *cmd_ctx) { - struct cli_ctx *cctx = cmd_ctx->cctx; errno_t ret; uint8_t *body; size_t body_len; @@ -957,13 +961,16 @@ ssh_cmd_build_reply(struct ssh_cmd_ctx *cmd_ctx) const char *name; char *fqname; uint32_t fqname_len; - struct ssh_ctx *ssh_ctx = talloc_get_type(cctx->rctx->pvt_ctx, - struct ssh_ctx); TALLOC_CTX *tmp_ctx; + struct ssh_ctx *ssh_ctx; + struct cli_protocol *pctx; - ret = sss_packet_new(cctx->creq, 0, - sss_packet_get_cmd(cctx->creq->in), - &cctx->creq->out); + ssh_ctx = talloc_get_type(cmd_ctx->cctx->rctx->pvt_ctx, struct ssh_ctx); + pctx = talloc_get_type(cmd_ctx->cctx->protocol_ctx, struct cli_protocol); + + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { return ret; } @@ -1007,11 +1014,11 @@ ssh_cmd_build_reply(struct ssh_cmd_ctx *cmd_ctx) } } - ret = sss_packet_grow(cctx->creq->out, 2*sizeof(uint32_t)); + ret = sss_packet_grow(pctx->creq->out, 2*sizeof(uint32_t)); if (ret != EOK) { goto done; } - sss_packet_get_body(cctx->creq->out, &body, &body_len); + sss_packet_get_body(pctx->creq->out, &body, &body_len); SAFEALIGN_SET_UINT32(body+c, count, &c); SAFEALIGN_SET_UINT32(body+c, 0, &c); @@ -1096,17 +1103,19 @@ ssh_cmd_send_error(struct ssh_cmd_ctx *cmd_ctx, static errno_t ssh_cmd_send_reply(struct ssh_cmd_ctx *cmd_ctx) { - struct cli_ctx *cctx = cmd_ctx->cctx; + struct cli_protocol *pctx; errno_t ret; + pctx = talloc_get_type(cmd_ctx->cctx->protocol_ctx, struct cli_protocol); + /* create response packet */ ret = ssh_cmd_build_reply(cmd_ctx); if (ret != EOK) { return ret; } - sss_packet_set_error(cctx->creq->out, EOK); - sss_cmd_done(cctx, cmd_ctx); + sss_packet_set_error(pctx->creq->out, EOK); + sss_cmd_done(cmd_ctx->cctx, cmd_ctx); return EOK; } diff --git a/src/responder/sudo/sudosrv.c b/src/responder/sudo/sudosrv.c index e93ec75b4..e0346033e 100644 --- a/src/responder/sudo/sudosrv.c +++ b/src/responder/sudo/sudosrv.c @@ -99,6 +99,7 @@ int sudo_process_init(TALLOC_CTX *mem_ctx, &monitor_sudo_methods, "SUDO", &sudo_dp_methods.vtable, + sss_connection_setup, &rctx); if (ret != EOK) { DEBUG(SSSDBG_FATAL_FAILURE, "sss_process_init() failed\n"); diff --git a/src/responder/sudo/sudosrv_cmd.c b/src/responder/sudo/sudosrv_cmd.c index eeb388c48..3bed22b6f 100644 --- a/src/responder/sudo/sudosrv_cmd.c +++ b/src/responder/sudo/sudosrv_cmd.c @@ -38,14 +38,17 @@ static errno_t sudosrv_cmd_send_reply(struct sudo_cmd_ctx *cmd_ctx, uint8_t *packet_body = NULL; size_t packet_len = 0; struct cli_ctx *cli_ctx = cmd_ctx->cli_ctx; + struct cli_protocol *pctx; TALLOC_CTX *tmp_ctx; tmp_ctx = talloc_new(NULL); if (!tmp_ctx) return ENOMEM; - ret = sss_packet_new(cli_ctx->creq, 0, - sss_packet_get_cmd(cli_ctx->creq->in), - &cli_ctx->creq->out); + pctx = talloc_get_type(cli_ctx->protocol_ctx, struct cli_protocol); + + ret = sss_packet_new(pctx->creq, 0, + sss_packet_get_cmd(pctx->creq->in), + &pctx->creq->out); if (ret != EOK) { DEBUG(SSSDBG_CRIT_FAILURE, "Unable to create a new packet [%d]; %s\n", @@ -53,16 +56,16 @@ static errno_t sudosrv_cmd_send_reply(struct sudo_cmd_ctx *cmd_ctx, goto done; } - ret = sss_packet_grow(cli_ctx->creq->out, response_len); + ret = sss_packet_grow(pctx->creq->out, response_len); if (ret != EOK) { DEBUG(SSSDBG_CRIT_FAILURE, "Unable to create response: %s\n", strerror(ret)); goto done; } - sss_packet_get_body(cli_ctx->creq->out, &packet_body, &packet_len); + sss_packet_get_body(pctx->creq->out, &packet_body, &packet_len); memcpy(packet_body, response_body, response_len); - sss_packet_set_error(cli_ctx->creq->out, EOK); + sss_packet_set_error(pctx->creq->out, EOK); sss_cmd_done(cmd_ctx->cli_ctx, cmd_ctx); ret = EOK; @@ -172,7 +175,8 @@ static int sudosrv_cmd(enum sss_sudo_type type, struct cli_ctx *cli_ctx) struct sudo_cmd_ctx *cmd_ctx = NULL; uint8_t *query_body = NULL; size_t query_len = 0; - uint32_t protocol = cli_ctx->cli_protocol_version->version; + struct cli_protocol *pctx; + uint32_t protocol; errno_t ret; /* create cmd_ctx */ @@ -192,6 +196,9 @@ static int sudosrv_cmd(enum sss_sudo_type type, struct cli_ctx *cli_ctx) return EFAULT; } + pctx = talloc_get_type(cli_ctx->protocol_ctx, struct cli_protocol); + protocol = pctx->cli_protocol_version->version; + /* if protocol is invalid return */ switch (protocol) { case 0: @@ -212,7 +219,7 @@ static int sudosrv_cmd(enum sss_sudo_type type, struct cli_ctx *cli_ctx) } /* parse query */ - sss_packet_get_body(cli_ctx->creq->in, &query_body, &query_len); + sss_packet_get_body(pctx->creq->in, &query_body, &query_len); if (query_len <= 0 || query_body == NULL) { DEBUG(SSSDBG_CRIT_FAILURE, "Query is empty\n"); ret = EINVAL; diff --git a/src/tests/cmocka/common_mock_resp.c b/src/tests/cmocka/common_mock_resp.c index ce73d1b45..dc03d39b6 100644 --- a/src/tests/cmocka/common_mock_resp.c +++ b/src/tests/cmocka/common_mock_resp.c @@ -63,12 +63,23 @@ mock_cctx(TALLOC_CTX *mem_ctx, struct resp_ctx *rctx) cctx = talloc_zero(mem_ctx, struct cli_ctx); if (!cctx) return NULL; - cctx->creq = talloc_zero(cctx, struct cli_request); - if (cctx->creq == NULL) { - talloc_free(cctx); + cctx->rctx = rctx; + return cctx; +} + +struct cli_protocol * +mock_prctx(TALLOC_CTX *mem_ctx) +{ + struct cli_protocol *prctx; + + prctx = talloc_zero(mem_ctx, struct cli_protocol); + if (!prctx) return NULL; + + prctx->creq = talloc_zero(prctx, struct cli_request); + if (prctx->creq == NULL) { + talloc_free(prctx); return NULL; } - cctx->rctx = rctx; - return cctx; + return prctx; } diff --git a/src/tests/cmocka/common_mock_resp.h b/src/tests/cmocka/common_mock_resp.h index a4d8f55c7..aab6a94e4 100644 --- a/src/tests/cmocka/common_mock_resp.h +++ b/src/tests/cmocka/common_mock_resp.h @@ -38,6 +38,9 @@ mock_rctx(TALLOC_CTX *mem_ctx, struct cli_ctx * mock_cctx(TALLOC_CTX *mem_ctx, struct resp_ctx *rctx); +struct cli_protocol * +mock_prctx(TALLOC_CTX *mem_ctx); + /* When mocking a module that calls sss_dp_get_account_{send,recv} * requests, your test, when linked against this module, will call * the mock functions instead. Then you can simulate results of the diff --git a/src/tests/cmocka/test_nss_srv.c b/src/tests/cmocka/test_nss_srv.c index d0b1e28e0..945e2b0c2 100644 --- a/src/tests/cmocka/test_nss_srv.c +++ b/src/tests/cmocka/test_nss_srv.c @@ -119,13 +119,17 @@ static void set_cmd_cb(cmd_cb_fn_t fn) void __wrap_sss_cmd_done(struct cli_ctx *cctx, void *freectx) { - struct sss_packet *packet = cctx->creq->out; + struct cli_protocol *pctx; + struct sss_packet *packet; uint8_t *body; size_t blen; cmd_cb_fn_t check_cb; check_cb = sss_mock_ptr_type(cmd_cb_fn_t); + pctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + packet = pctx->creq->out; + __real_sss_packet_get_body(packet, &body, &blen); nss_test_ctx->tctx->error = check_cb(sss_packet_get_status(packet), @@ -1070,6 +1074,15 @@ void test_nss_setup(struct sss_test_conf_param params[], /* Create client context */ nss_test_ctx->cctx = mock_cctx(nss_test_ctx, nss_test_ctx->rctx); assert_non_null(nss_test_ctx->cctx); + + /* Add nss specific state_ctx */ + nss_connection_setup(nss_test_ctx->cctx); + assert_non_null(nss_test_ctx->cctx->state_ctx); + + /* do after previous setup as the former nulls procotol_ctx */ + nss_test_ctx->cctx->protocol_ctx = mock_prctx(nss_test_ctx->cctx); + assert_non_null(nss_test_ctx->cctx->protocol_ctx); + } static int test_nss_getgrnam_check(struct group *expected, struct group *gr, const int nmem) diff --git a/src/tests/cmocka/test_pam_srv.c b/src/tests/cmocka/test_pam_srv.c index e4ad5b650..6f56071f8 100644 --- a/src/tests/cmocka/test_pam_srv.c +++ b/src/tests/cmocka/test_pam_srv.c @@ -217,6 +217,7 @@ void test_pam_setup(struct sss_test_conf_param dom_params[], struct sss_test_conf_param monitor_params[], void **state) { + struct cli_protocol *prctx; errno_t ret; pam_test_ctx = talloc_zero(NULL, struct pam_test_ctx); @@ -256,9 +257,12 @@ void test_pam_setup(struct sss_test_conf_param dom_params[], /* Create client context */ pam_test_ctx->cctx = mock_cctx(pam_test_ctx, pam_test_ctx->rctx); assert_non_null(pam_test_ctx->cctx); - - pam_test_ctx->cctx->cli_protocol_version = register_cli_protocol_version(); pam_test_ctx->cctx->ev = pam_test_ctx->tctx->ev; + + prctx = mock_prctx(pam_test_ctx->cctx); + assert_non_null(prctx); + pam_test_ctx->cctx->protocol_ctx = prctx; + prctx->cli_protocol_version = register_cli_protocol_version(); } static void pam_test_setup_common(void) @@ -418,11 +422,14 @@ void __real_sss_packet_get_body(struct sss_packet *packet, void __wrap_sss_cmd_done(struct cli_ctx *cctx, void *freectx) { - struct sss_packet *packet = cctx->creq->out; + struct cli_protocol *prctx; + struct sss_packet *packet; uint8_t *body; size_t blen; cmd_cb_fn_t check_cb; + prctx = talloc_get_type(cctx->protocol_ctx, struct cli_protocol); + packet = prctx->creq->out; assert_non_null(packet); check_cb = sss_mock_ptr_type(cmd_cb_fn_t); |