diff options
-rw-r--r-- | include/mscp.h | 2 | ||||
-rw-r--r-- | src/atomic.h | 56 | ||||
-rw-r--r-- | src/main.c | 4 | ||||
-rw-r--r-- | src/mscp.c | 156 | ||||
-rw-r--r-- | src/path.c | 12 | ||||
-rw-r--r-- | src/path.h | 9 | ||||
-rw-r--r-- | test/test_python.py | 6 |
7 files changed, 169 insertions, 76 deletions
diff --git a/include/mscp.h b/include/mscp.h index 3e8f80f..965af0b 100644 --- a/include/mscp.h +++ b/include/mscp.h @@ -184,7 +184,7 @@ int mscp_scan_join(struct mscp *m); * * @param m mscp instance. * - * @return 0 on success, < 0 if an error occured. + * @return number of threads on success, < 0 if an error occured. * mscp_get_error() can be used to retrieve error message. * * @see mscp_join() diff --git a/src/atomic.h b/src/atomic.h index 87ba20d..09f9f57 100644 --- a/src/atomic.h +++ b/src/atomic.h @@ -20,6 +20,8 @@ static inline refcnt refcnt_dec(refcnt *cnt) } +/* mutex */ + typedef pthread_mutex_t lock; static inline void lock_init(lock *l) @@ -44,12 +46,58 @@ static inline void lock_release_via_cleanup(void *l) lock_release(l); } -#define LOCK_ACQUIRE_THREAD(l) \ - lock_acquire(l); \ - pthread_cleanup_push(lock_release_via_cleanup, l) +#define LOCK_ACQUIRE(l) \ + lock_acquire(l); \ + pthread_cleanup_push(lock_release_via_cleanup, l) + +#define LOCK_RELEASE() \ + pthread_cleanup_pop(1) + + + +/* read/write lock */ +typedef pthread_rwlock_t rwlock; + +static inline void rwlock_init(rwlock *rw) +{ + pthread_rwlock_init(rw, NULL); +} + +static inline void rwlock_read_acquire(rwlock *rw) +{ + int ret = pthread_rwlock_rdlock(rw); + assert(ret == 0); +} + +static inline void rwlock_write_acquire(rwlock *rw) +{ + int ret = pthread_rwlock_wrlock(rw); + assert(ret == 0); +} + +static inline void rwlock_release(rwlock *rw) +{ + int ret = pthread_rwlock_unlock(rw); + assert(ret == 0); +} +static inline void rwlock_release_via_cleanup(void *rw) +{ + rwlock_release(rw); +} + +#define RWLOCK_READ_ACQUIRE(rw) \ + rwlock_read_acquire(rw); \ + pthread_cleanup_push(rwlock_release_via_cleanup, rw) + +#define RWLOCK_WRITE_ACQUIRE(rw) \ + rwlock_write_acquire(rw); \ + pthread_cleanup_push(rwlock_release_via_cleanup, rw) -#define LOCK_RELEASE_THREAD() \ + +#define RWLOCK_RELEASE() \ pthread_cleanup_pop(1) + + #endif /* _ATOMIC_H_ */ @@ -380,11 +380,11 @@ int main(int argc, char **argv) ret = mscp_start(m); if (ret < 0) - fprintf(stderr, "%s\n", mscp_get_error()); + fprintf(stderr, "mscp_start: %s\n", mscp_get_error()); ret = mscp_join(m); if (ret != 0) - fprintf(stderr, "%s\n", mscp_get_error()); + fprintf(stderr, "mscp_join: %s\n", mscp_get_error()); pthread_cancel(tid_stat); pthread_join(tid_stat, NULL); @@ -40,11 +40,15 @@ struct mscp { int ret_scan; /* return code from scan thread */ size_t total_bytes; /* total bytes to be transferred */ - struct mscp_thread *threads; + + struct list_head thread_list; + rwlock thread_rwlock; }; struct mscp_thread { + struct list_head list; /* mscp->thread_list */ + struct mscp *m; int id; sftp_session sftp; @@ -56,7 +60,7 @@ struct mscp_thread { }; struct src { - struct list_head list; + struct list_head list; /* mscp->src_list */ char *path; }; @@ -211,7 +215,7 @@ struct mscp *mscp_init(const char *remote_host, int direction, int n; if (!remote_host) { - mscp_set_error("empty remote host\n"); + mscp_set_error("empty remote host"); return NULL; } @@ -238,6 +242,9 @@ struct mscp *mscp_init(const char *remote_host, int direction, INIT_LIST_HEAD(&m->path_list); chunk_pool_init(&m->cp); + INIT_LIST_HEAD(&m->thread_list); + rwlock_init(&m->thread_rwlock); + if ((m->sem = sem_create(o->max_startups)) == NULL) { mscp_set_error("sem_create: %s", strerrno()); goto free_out; @@ -339,11 +346,14 @@ static int get_page_mask(void) static void mscp_stop_copy_thread(struct mscp *m) { - int n; - for (n = 0; n < m->opts->nr_threads; n++) { - if (m->threads[n].tid && !m->threads[n].finished) - pthread_cancel(m->threads[n].tid); - } + struct mscp_thread *t; + + RWLOCK_READ_ACQUIRE(&m->thread_rwlock); + list_for_each_entry(t, &m->thread_list, list) { + if (!t->finished) + pthread_cancel(t->tid); + } + RWLOCK_RELEASE(); } static void mscp_stop_scan_thread(struct mscp *m) @@ -448,10 +458,10 @@ int mscp_scan(struct mscp *m) return -1; } - /* need scan finished or over nr_threads chunks to determine - * actual number of threads (and connections). If the number - * of chunks are smaller than nr_threads, we adjust nr_threads - * to the number of chunks. + /* We wait for there are over nr_threads chunks to determine + * actual number of threads (and connections), or scan + * finished. If the number of chunks are smaller than + * nr_threads, we adjust nr_threads to the number of chunks. */ while (!chunk_pool_is_filled(&m->cp) && chunk_pool_size(&m->cp) < m->opts->nr_threads) @@ -474,9 +484,40 @@ int mscp_scan_join(struct mscp *m) static void *mscp_copy_thread(void *arg); +static struct mscp_thread *mscp_copy_thread_spawn(struct mscp *m, int id) +{ + struct mscp_thread *t; + int ret; + + t = malloc(sizeof(*t)); + if (!t){ + mscp_set_error("malloc: %s,", strerrno()); + return NULL; + } + + memset(t, 0, sizeof(*t)); + t->m = m; + t->id = id; + if (m->cores == NULL) + t->cpu = -1; /* not pinned to cpu */ + else + t->cpu = m->cores[id % m->nr_cores]; + + ret = pthread_create(&t->tid, NULL, mscp_copy_thread, t); + if (ret < 0) { + mscp_set_error("pthread_create error: %d", ret); + free(t); + return NULL; + } + + return t; +} + + int mscp_start(struct mscp *m) { - int n, ret; + struct mscp_thread *t; + int n, ret = 0; if ((n = chunk_pool_size(&m->cp)) < m->opts->nr_threads) { mpr_notice(m->msg_fp, "we have only %d chunk(s). " @@ -484,63 +525,46 @@ int mscp_start(struct mscp *m) m->opts->nr_threads = n; } - /* scan thread instances */ - m->threads = calloc(m->opts->nr_threads, sizeof(struct mscp_thread)); - memset(m->threads, 0, m->opts->nr_threads * sizeof(struct mscp_thread)); for (n = 0; n < m->opts->nr_threads; n++) { - struct mscp_thread *t = &m->threads[n]; - t->m = m; - t->id = n; - if (!m->cores) - t->cpu = -1; - else - t->cpu = m->cores[n % m->nr_cores]; - - ret = pthread_create(&t->tid, NULL, mscp_copy_thread, t); - if (ret < 0) { - mscp_set_error("pthread_create error: %d", ret); - mscp_stop(m); - return -1; - } + t = mscp_copy_thread_spawn(m, n); + if (!t) { + mpr_err(m->msg_fp, "failed to spawn copy thread\n"); + break; + } + RWLOCK_WRITE_ACQUIRE(&m->thread_rwlock); + list_add_tail(&t->list, &m->thread_list); + RWLOCK_RELEASE(); } - return 0; + return n; } int mscp_join(struct mscp *m) { + struct mscp_thread *t; int n, ret = 0; /* waiting for scan thread joins... */ ret = mscp_scan_join(m); /* waiting for copy threads join... */ - for (n = 0; n < m->opts->nr_threads; n++) { - if (m->threads[n].tid) { - pthread_join(m->threads[n].tid, NULL); - if (m->threads[n].ret < 0) - ret = m->threads[n].ret; - } - } + RWLOCK_READ_ACQUIRE(&m->thread_rwlock); + list_for_each_entry(t, &m->thread_list, list) { + pthread_join(t->tid, NULL); + if (t->ret < 0) + ret = t->ret; + if (t->sftp) { + ssh_sftp_close(t->sftp); + t->sftp = NULL; + } + } + RWLOCK_RELEASE(); if (m->first) { ssh_sftp_close(m->first); m->first = NULL; } - if (m->threads) { - for (n = 0; n < m->opts->nr_threads; n++) { - struct mscp_thread *t = &m->threads[n]; - if (t->ret != 0) - ret = ret; - - if (t->sftp) { - ssh_sftp_close(t->sftp); - t->sftp = NULL; - } - } - } - return ret; } @@ -567,7 +591,7 @@ void *mscp_copy_thread(void *arg) } if (sem_wait(m->sem) < 0) { - mscp_set_error("sem_wait: %s\n", strerrno()); + mscp_set_error("sem_wait: %s", strerrno()); mpr_err(m->msg_fp, "%s", mscp_get_error()); goto err_out; } @@ -577,7 +601,7 @@ void *mscp_copy_thread(void *arg) t->sftp = ssh_init_sftp_session(m->remote, m->ssh_opts); if (sem_post(m->sem) < 0) { - mscp_set_error("sem_post: %s\n", strerrno()); + mscp_set_error("sem_post: %s", strerrno()); mpr_err(m->msg_fp, "%s", mscp_get_error()); goto err_out; } @@ -629,6 +653,7 @@ void *mscp_copy_thread(void *arg) return NULL; err_out: + t->finished = true; t->ret = -1; return NULL; } @@ -658,6 +683,13 @@ static void free_chunk(struct list_head *list) free(c); } +static void free_thread(struct list_head *list) +{ + struct mscp_thread *t; + t = list_entry(list, typeof(*t), list); + free(t); +} + void mscp_cleanup(struct mscp *m) { if (m->first) { @@ -674,10 +706,9 @@ void mscp_cleanup(struct mscp *m) chunk_pool_release(&m->cp); chunk_pool_init(&m->cp); - if (m->threads) { - free(m->threads); - m->threads = NULL; - } + RWLOCK_WRITE_ACQUIRE(&m->thread_rwlock); + list_free_f(&m->thread_list, free_thread); + RWLOCK_RELEASE(); } void mscp_free(struct mscp *m) @@ -694,16 +725,19 @@ void mscp_free(struct mscp *m) void mscp_get_stats(struct mscp *m, struct mscp_stats *s) { + struct mscp_thread *t; bool finished = true; - int n; s->total = m->total_bytes; - for (s->done = 0, n = 0; n < m->opts->nr_threads; n++) { - s->done += m->threads[n].done; + s->done = 0; - if (!m->threads[n].done) + RWLOCK_READ_ACQUIRE(&m->thread_rwlock); + list_for_each_entry(t, &m->thread_list, list) { + s->done += t->done; + if (!t->finished) finished = false; } + RWLOCK_RELEASE(); s->finished = finished; } @@ -27,10 +27,10 @@ void chunk_pool_init(struct chunk_pool *cp) static void chunk_pool_add(struct chunk_pool *cp, struct chunk *c) { - LOCK_ACQUIRE_THREAD(&cp->lock); + LOCK_ACQUIRE(&cp->lock); list_add_tail(&c->list, &cp->list); cp->count += 1; - LOCK_RELEASE_THREAD(); + LOCK_RELEASE(); } void chunk_pool_set_filled(struct chunk_pool *cp) @@ -54,7 +54,7 @@ struct chunk *chunk_pool_pop(struct chunk_pool *cp) struct list_head *first; struct chunk *c = NULL; - LOCK_ACQUIRE_THREAD(&cp->lock); + LOCK_ACQUIRE(&cp->lock); first = cp->list.next; if (list_empty(&cp->list)) { if (!chunk_pool_is_filled(cp)) @@ -65,7 +65,7 @@ struct chunk *chunk_pool_pop(struct chunk_pool *cp) c = list_entry(first, struct chunk, list); list_del(first); } - LOCK_RELEASE_THREAD(); + LOCK_RELEASE(); /* return CHUNK_POP_WAIT would be very rare case, because it * means copying over SSH is faster than traversing @@ -363,7 +363,7 @@ static int prepare_dst_path(FILE *msg_fp, struct path *p, sftp_session dst_sftp) { int ret = 0; - LOCK_ACQUIRE_THREAD(&p->lock); + LOCK_ACQUIRE(&p->lock); if (p->state == FILE_STATE_INIT) { if (touch_dst_path(p, dst_sftp) < 0) { ret = -1; @@ -374,7 +374,7 @@ static int prepare_dst_path(FILE *msg_fp, struct path *p, sftp_session dst_sftp) } out: - LOCK_RELEASE_THREAD(); + LOCK_RELEASE(); return ret; } @@ -201,11 +201,16 @@ static int mscp_stat(const char *path, mstat *s, sftp_session sftp) if (sftp) { s->r = sftp_stat(sftp, path); - if (!s->r) + if (!s->r) { + mscp_set_error("sftp_stat: %s %s", + sftp_get_ssh_error(sftp), path); return -1; + } } else { - if (stat(path, &s->l) < 0) + if (stat(path, &s->l) < 0) { + mscp_set_error("stat: %s %s", strerrno(), path); return -1; + } } return 0; diff --git a/test/test_python.py b/test/test_python.py index ea2d278..a6b2787 100644 --- a/test/test_python.py +++ b/test/test_python.py @@ -104,6 +104,12 @@ def test_login_failed(): with pytest.raises(RuntimeError) as e: m.connect() +def test_get_stat_before_copy_start(): + m = mscp.mscp("localhost", mscp.LOCAL2REMOTE) + m.connect() + (total, done, finished) = m.stats() + assert total == 0 and done == 0 + param_invalid_kwargs = [ { "nr_threads": -1 }, |