diff --git a/include/libssh/buffer.h b/include/libssh/buffer.h index a55a1b40..e34e075c 100644 --- a/include/libssh/buffer.h +++ b/include/libssh/buffer.h @@ -33,6 +33,8 @@ int ssh_buffer_add_u8(ssh_buffer buffer, uint8_t data); int ssh_buffer_add_u16(ssh_buffer buffer, uint16_t data); int ssh_buffer_add_u32(ssh_buffer buffer, uint32_t data); int ssh_buffer_add_u64(ssh_buffer buffer, uint64_t data); +ssize_t ssh_buffer_add_func(ssh_buffer buffer, ssh_add_func f, size_t max_bytes, + void *userdata); int ssh_buffer_validate_length(struct ssh_buffer_struct *buffer, size_t len); diff --git a/include/libssh/libssh.h b/include/libssh/libssh.h index 7857a77b..3eef7a16 100644 --- a/include/libssh/libssh.h +++ b/include/libssh/libssh.h @@ -833,6 +833,7 @@ LIBSSH_API const char* ssh_get_hmac_in(ssh_session session); LIBSSH_API const char* ssh_get_hmac_out(ssh_session session); LIBSSH_API ssh_buffer ssh_buffer_new(void); +LIBSSH_API ssh_buffer ssh_buffer_new_size(uint32_t size, uint32_t headroom); LIBSSH_API void ssh_buffer_free(ssh_buffer buffer); #define SSH_BUFFER_FREE(x) \ do { if ((x) != NULL) { ssh_buffer_free(x); x = NULL; } } while(0) @@ -843,6 +844,8 @@ LIBSSH_API void *ssh_buffer_get(ssh_buffer buffer); LIBSSH_API uint32_t ssh_buffer_get_len(ssh_buffer buffer); LIBSSH_API int ssh_session_set_disconnect_message(ssh_session session, const char *message); +typedef ssize_t (*ssh_add_func) (void *ptr, size_t max_bytes, void *userdata); + #ifndef LIBSSH_LEGACY_0_4 #include "libssh/legacy.h" #endif diff --git a/include/libssh/sftp.h b/include/libssh/sftp.h index c855df8a..0fcdb9b8 100644 --- a/include/libssh/sftp.h +++ b/include/libssh/sftp.h @@ -565,6 +565,10 @@ LIBSSH_API int sftp_async_read(sftp_file file, void *data, uint32_t len, uint32_ */ LIBSSH_API ssize_t sftp_write(sftp_file file, const void *buf, size_t count); +LIBSSH_API ssize_t sftp_async_write(sftp_file file, ssh_add_func f, size_t count, + void *userdata, uint32_t* id); +LIBSSH_API int sftp_async_write_end(sftp_file file, uint32_t id, int blocking); + /** * @brief Seek to a specific location in a file. * diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index c090fef7..e2f86309 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -435,6 +435,11 @@ if (BUILD_STATIC_LIB) if (WIN32) target_compile_definitions(ssh-static PUBLIC "LIBSSH_STATIC") endif (WIN32) + + install(TARGETS ssh-static + EXPORT libssh-config + LIBRARY DESTINATION ${CMAKE_INSTALL_LIBDIR} + COMPONENT libraries) endif (BUILD_STATIC_LIB) message(STATUS "Threads_FOUND=${Threads_FOUND}") diff --git a/src/buffer.c b/src/buffer.c index e0068015..cc0caf35 100644 --- a/src/buffer.c +++ b/src/buffer.c @@ -141,6 +141,40 @@ struct ssh_buffer_struct *ssh_buffer_new(void) return buf; } +/** + * @brief Create a new SSH buffer with a specified size and headroom. + * + * @param[in] len length for newly initialized SSH buffer. + * @param[in] headroom length for headroom + * @return A newly initialized SSH buffer, NULL on error. + */ +struct ssh_buffer_struct *ssh_buffer_new_size(uint32_t len, uint32_t headroom) +{ + struct ssh_buffer_struct *buf = NULL; + int rc; + + if (len < headroom) + return NULL; + + buf = calloc(1, sizeof(struct ssh_buffer_struct)); + if (buf == NULL) { + return NULL; + } + + rc = ssh_buffer_allocate_size(buf, len); + if (rc != 0) { + SAFE_FREE(buf); + return NULL; + } + + buf->pos += headroom; + buf->used += headroom; + + buffer_verify(buf); + + return buf; +} + /** * @brief Deallocate a SSH buffer. * @@ -328,6 +362,49 @@ int ssh_buffer_add_data(struct ssh_buffer_struct *buffer, const void *data, uint return 0; } +/** + * @brief Add data at the tail of a buffer by an external function + * + * @param[in] buffer The buffer to add data. + * + * @param[in] f function that adds data to the buffer. + * + * @param[in] max_bytes The maximum length of the data to add. + * + * @return actual bytes added on success, < 0 on error. + */ +ssize_t ssh_buffer_add_func(struct ssh_buffer_struct *buffer, ssh_add_func f, + size_t max_bytes, void *userdata) +{ + ssize_t actual; + + if (buffer == NULL) { + return -1; + } + + buffer_verify(buffer); + + if (buffer->used + max_bytes < max_bytes) { + return -1; + } + + if (buffer->allocated < (buffer->used + max_bytes)) { + if (buffer->pos > 0) { + buffer_shift(buffer); + } + if (realloc_buffer(buffer, buffer->used + max_bytes) < 0) { + return -1; + } + } + + if ((actual = f(buffer->data + buffer->used, max_bytes, userdata)) < 0) + return -1; + + buffer->used += actual; + buffer_verify(buffer); + return actual; +} + /** * @brief Ensure the buffer has at least a certain preallocated size. * diff --git a/src/sftp.c b/src/sftp.c index e01012a8..702623a0 100644 --- a/src/sftp.c +++ b/src/sftp.c @@ -2228,6 +2228,132 @@ ssize_t sftp_write(sftp_file file, const void *buf, size_t count) { return -1; /* not reached */ } +/* + * sftp_async_write is based on and sftp_async_write_end is copied from + * https://github.com/limes-datentechnik-gmbh/libssh + * + * sftp_async_write has some optimizations: + * - use ssh_buffer_new_size() to reduce realoc_buffer. + * - use ssh_buffer_add_func() to avoid memcpy from read buffer to ssh buffer. + */ +ssize_t sftp_async_write(sftp_file file, ssh_add_func f, size_t count, void *userdata, + uint32_t* id) { + sftp_session sftp = file->sftp; + ssh_buffer buffer; + uint32_t buf_sz; + ssize_t actual; + int len; + int packetlen; + int rc; + +#define HEADROOM 16 + /* sftp_packet_write() prepends a 5-bytes (uint32_t length and + * 1-byte type) header to the head of the payload by + * ssh_buffer_prepend_data(). Inserting headroom by + * ssh_buffer_new_size() eliminates memcpy for prepending the + * header. + */ + + buf_sz = (HEADROOM + /* for header */ + sizeof(uint32_t) + /* id */ + ssh_string_len(file->handle) + 4 + /* file->handle */ + sizeof(uint64_t) + /* file->offset */ + sizeof(uint32_t) + /* count */ + count); /* datastring */ + + buffer = ssh_buffer_new_size(buf_sz, HEADROOM); + if (buffer == NULL) { + ssh_set_error_oom(sftp->session); + return -1; + } + + *id = sftp_get_new_id(file->sftp); + + rc = ssh_buffer_pack(buffer, + "dSqd", + *id, + file->handle, + file->offset, + count); /* len of datastring */ + + if (rc != SSH_OK){ + ssh_set_error_oom(sftp->session); + ssh_buffer_free(buffer); + return SSH_ERROR; + } + + actual = ssh_buffer_add_func(buffer, f, count, userdata); + if (actual < 0){ + ssh_set_error_oom(sftp->session); + ssh_buffer_free(buffer); + return SSH_ERROR; + } + + packetlen=ssh_buffer_get_len(buffer)+5; + len = sftp_packet_write(file->sftp, SSH_FXP_WRITE, buffer); + ssh_buffer_free(buffer); + if (len < 0) { + return SSH_ERROR; + } else if (len != packetlen) { + ssh_set_error(sftp->session, SSH_FATAL, + "Could only send %d of %d bytes to remote host!", len, packetlen); + SSH_LOG(SSH_LOG_PACKET, + "Could not write as much data as expected"); + return SSH_ERROR; + } + + file->offset += actual; + + return actual; +} + +int sftp_async_write_end(sftp_file file, uint32_t id, int blocking) { + sftp_session sftp = file->sftp; + sftp_message msg = NULL; + sftp_status_message status; + + msg = sftp_dequeue(sftp, id); + while (msg == NULL) { + if (!blocking && ssh_channel_poll(sftp->channel, 0) == 0) { + /* we cannot block */ + return SSH_AGAIN; + } + if (sftp_read_and_dispatch(sftp) < 0) { + /* something nasty has happened */ + return SSH_ERROR; + } + msg = sftp_dequeue(sftp, id); + } + + switch (msg->packet_type) { + case SSH_FXP_STATUS: + status = parse_status_msg(msg); + sftp_message_free(msg); + if (status == NULL) { + return SSH_ERROR; + } + sftp_set_error(sftp, status->status); + switch (status->status) { + case SSH_FX_OK: + status_msg_free(status); + return SSH_OK; + default: + break; + } + ssh_set_error(sftp->session, SSH_REQUEST_DENIED, + "SFTP server: %s", status->errormsg); + status_msg_free(status); + return SSH_ERROR; + default: + ssh_set_error(sftp->session, SSH_FATAL, + "Received message %d during write!", msg->packet_type); + sftp_message_free(msg); + return SSH_ERROR; + } + + return SSH_ERROR; /* not reached */ +} + /* Seek to a specific location in a file. */ int sftp_seek(sftp_file file, uint32_t new_offset) { if (file == NULL) {