diff --git a/libmemcached/sasl.cc b/libmemcached/sasl.cc index 3aefb701..bd67e06d 100644 --- a/libmemcached/sasl.cc +++ b/libmemcached/sasl.cc @@ -43,6 +43,8 @@ #include #include +#define ASCII_SASL + void memcached_set_sasl_callbacks(memcached_st *ptr, const sasl_callback_t *callbacks) { @@ -120,7 +122,11 @@ static void sasl_startup_function(void) } // extern "C" +#ifdef ASCII_SASL +static memcached_return_t memcached_sasl_authenticate_connection_binary(memcached_server_st *server) +#else memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *server) +#endif { if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) { @@ -270,6 +276,146 @@ memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *s return memcached_set_error(*server, rc, MEMCACHED_AT); } +#ifdef ASCII_SASL +static memcached_return_t memcached_sasl_authenticate_connection_ascii(memcached_server_st *server) +{ + char req_hdr[1024]; + unsigned int hdrlen; + char buffer[MEMCACHED_DEFAULT_COMMAND_SIZE + MEMCACHED_MAX_KEY]; + + hdrlen = snprintf(req_hdr, 1024, "sasl mech\r\n"); + assert(hdrlen < 1024); + + if (memcached_io_write(server, req_hdr, hdrlen, 1) != hdrlen) { + return MEMCACHED_WRITE_FAILURE; + } + + memcached_server_response_increment(server); + + memcached_return_t rc = memcached_response(server, buffer, sizeof(buffer), NULL); + assert(memcached_response(server, buffer, sizeof(buffer), NULL) == MEMCACHED_END); + if (memcached_failed(rc)) { + if (rc == MEMCACHED_PROTOCOL_ERROR) { + /* If the server doesn't support SASL it will return PROTOCOL_ERROR. + * This error may also be returned for other errors, but let's assume + * that the server don't support SASL and treat it as success and + * let the client fail with the next operation if the error was + * caused by another problem.... + */ + rc= MEMCACHED_SUCCESS; + } + return rc; + } + + /* set ip addresses */ + char laddr[NI_MAXHOST + NI_MAXSERV]; + char raddr[NI_MAXHOST + NI_MAXSERV]; + + if (memcached_failed(rc = resolve_names(*server, laddr, sizeof(laddr), raddr, sizeof(raddr)))) { + return rc; + } + + int pthread_error; + if ((pthread_error = pthread_once(&sasl_startup_once, sasl_startup_function)) != 0) { + return memcached_set_errno(*server, pthread_error, MEMCACHED_AT); + } + + (void)pthread_mutex_lock(&sasl_startup_state_LOCK); + if (sasl_startup_state != SASL_OK) { + const char *sasl_error_msg= sasl_errstring(sasl_startup_state, NULL, NULL); + return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, + memcached_string_make_from_cstr(sasl_error_msg)); + } + (void)pthread_mutex_unlock(&sasl_startup_state_LOCK); + + sasl_conn_t *conn; + int ret; + if ((ret= sasl_client_new("memcached", server->hostname, laddr, raddr, server->root->sasl.callbacks, 0, &conn) ) != SASL_OK) { + const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL); + + sasl_dispose(&conn); + + return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, + memcached_string_make_from_cstr(sasl_error_msg)); + } + + const char *data; + const char *chosenmech; + unsigned int len; + ret= sasl_client_start(conn, memcached_result_value(&server->root->result), NULL, &data, &len, &chosenmech); + if (ret != SASL_OK and ret != SASL_CONTINUE) + { + const char *sasl_error_msg= sasl_errstring(ret, NULL, NULL); + + sasl_dispose(&conn); + + return memcached_set_error(*server, MEMCACHED_AUTH_PROBLEM, MEMCACHED_AT, + memcached_string_make_from_cstr(sasl_error_msg)); + } + + hdrlen = snprintf(req_hdr, 1024, "sasl auth %s %u\r\n", chosenmech, len); + + do { + /* send the packet */ + + struct libmemcached_io_vector_st vector[] = { + { hdrlen, req_hdr }, + { len, data }, + { 2, "\r\n" }, + }; + + if (memcached_io_writev(server, vector, 3, true) == -1) { + rc = MEMCACHED_WRITE_FAILURE; + break; + } + + memcached_server_response_increment(server); + + rc = memcached_response(server, buffer, sizeof(buffer), NULL); + if (rc == MEMCACHED_SUCCESS) { + assert(memcached_response(server, buffer, sizeof(buffer), NULL) == MEMCACHED_END); + } else { + break; + } + + ret = sasl_client_step(conn, memcached_result_value(&server->root->result), + memcached_result_length(&server->root->result), + NULL, &data, &len); + if (ret != SASL_OK && ret != SASL_CONTINUE) { + rc= MEMCACHED_AUTH_PROBLEM; + break; + } + + hdrlen = snprintf(req_hdr, 1024, "sasl auth %u\r\n", len); + } while (true); + + /* Release resources */ + sasl_dispose(&conn); + + return memcached_set_error(*server, rc, MEMCACHED_AT); +} + +memcached_return_t memcached_sasl_authenticate_connection(memcached_server_st *server) +{ + if (LIBMEMCACHED_WITH_SASL_SUPPORT == 0) + { + return MEMCACHED_NOT_SUPPORTED; + } + + if (server == NULL) + { + return MEMCACHED_INVALID_ARGUMENTS; + } + + /* SANITY CHECK: SASL can only be used with the binary protocol */ + if (server->root->flags.binary_protocol) { + return memcached_sasl_authenticate_connection_binary(server); + } else { + return memcached_sasl_authenticate_connection_ascii(server); + } +} +#endif + static int get_username(void *context, int id, const char **result, unsigned int *len) { if (!context || !result || (id != SASL_CB_USER && id != SASL_CB_AUTHNAME)) @@ -313,11 +459,13 @@ memcached_return_t memcached_set_sasl_auth_data(memcached_st *ptr, return MEMCACHED_INVALID_ARGUMENTS; } +#ifndef ASCII_SASL memcached_return_t ret; if (memcached_failed(ret= memcached_behavior_set(ptr, MEMCACHED_BEHAVIOR_BINARY_PROTOCOL, 1))) { return memcached_set_error(*ptr, ret, MEMCACHED_AT, memcached_literal_param("Unable change to binary protocol which is required for SASL.")); } +#endif memcached_destroy_sasl_auth_data(ptr);