diff --git a/client.c b/client.c index 7d1e346..847afac 100644 --- a/client.c +++ b/client.c @@ -1427,14 +1427,9 @@ submit_request(CMD_Request *request, CMD_Reply *reply) DEBUG_LOG("Received %d bytes", recv_status); read_length = recv_status; - if (read_length >= offsetof(CMD_Reply, data)) { - expected_length = PKL_ReplyLength(reply); - } else { - expected_length = 0; - } + expected_length = PKL_ReplyLength(reply, read_length); - bad_length = (read_length < expected_length || - expected_length < offsetof(CMD_Reply, data)); + bad_length = !expected_length || read_length < expected_length; if (!bad_length) { bad_sequence = reply->sequence != request->sequence; diff --git a/cmdmon.c b/cmdmon.c index 4ed2189..283ba97 100644 --- a/cmdmon.c +++ b/cmdmon.c @@ -279,7 +279,7 @@ do_size_checks(void) reply.reply = htons(i); reply.status = STT_SUCCESS; reply.data.manual_list.n_samples = htonl(MAX_MANUAL_LIST_SAMPLES); - reply_length = PKL_ReplyLength(&reply); + reply_length = PKL_ReplyLength(&reply, sizeof (reply)); if ((reply_length && reply_length < offsetof(CMD_Reply, data)) || reply_length > sizeof (CMD_Reply)) assert(0); @@ -393,7 +393,7 @@ transmit_reply(CMD_Reply *msg, union sockaddr_all *where_to) assert(0); } - tx_message_length = PKL_ReplyLength(msg); + tx_message_length = PKL_ReplyLength(msg, sizeof (*msg)); status = sendto(sock_fd, (void *) msg, tx_message_length, 0, &where_to->sa, addrlen); diff --git a/pktlength.c b/pktlength.c index 23a1b47..d93d139 100644 --- a/pktlength.c +++ b/pktlength.c @@ -183,12 +183,15 @@ PKL_CommandPaddingLength(CMD_Request *r) /* ================================================== */ int -PKL_ReplyLength(CMD_Reply *r) +PKL_ReplyLength(CMD_Reply *r, int read_length) { uint32_t type; assert(sizeof (reply_lengths) / sizeof (reply_lengths[0]) == N_REPLY_TYPES); + if (read_length < (int)offsetof(CMD_Reply, data)) + return 0; + type = ntohs(r->reply); /* Note that reply type codes start from 1, not 0 */ @@ -202,6 +205,9 @@ PKL_ReplyLength(CMD_Reply *r) if (r->status != htons(STT_SUCCESS)) return offsetof(CMD_Reply, data); + if (read_length < (int)offsetof(CMD_Reply, data.manual_list.samples)) + return 0; + ns = ntohl(r->data.manual_list.n_samples); if (ns > MAX_MANUAL_LIST_SAMPLES) return 0; diff --git a/pktlength.h b/pktlength.h index fad4c30..40646fe 100644 --- a/pktlength.h +++ b/pktlength.h @@ -35,6 +35,6 @@ extern int PKL_CommandLength(CMD_Request *r); extern int PKL_CommandPaddingLength(CMD_Request *r); -extern int PKL_ReplyLength(CMD_Reply *r); +extern int PKL_ReplyLength(CMD_Reply *r, int read_length); #endif /* GOT_PKTLENGTH_H */