diff --git a/include/downloader.h b/include/downloader.h index 4cd8337ad..f385823c9 100644 --- a/include/downloader.h +++ b/include/downloader.h @@ -25,6 +25,7 @@ #include #include #include +#include namespace kiwix { @@ -43,6 +44,14 @@ class AriaError : public std::runtime_error { }; +/** + * A representation of a current download. + * + * `Download` is not thread safe. User must care to not call method on a + * same download from different threads. + * However, it is safe to use different `Download`s from different threads. + */ + class Download { public: typedef enum { K_ACTIVE, K_WAITING, K_PAUSED, K_ERROR, K_COMPLETE, K_REMOVED, K_UNKNOWN } StatusResult; @@ -53,19 +62,89 @@ class Download { : mp_aria(p_aria), m_status(K_UNKNOWN), m_did(did) {}; - void updateStatus(bool follow=false); + + /** + * Update the status of the download. + * + * This call make an aria rpc call and is blocking. + * Some download (started with a metalink) are in fact several downloads. + * - A first one to download the metadlink. + * - A second one to download the real file. + * + * If `follow` is true, updateStatus tries to detect that and tracks + * the second download when the first one is finished. + * By passing false to `follow`, `Download` will only track the first download. + * + * `getFoo` methods are based on the last statusUpdate. + * + * @param follow: Do we have to follow following downloads. + */ + void updateStatus(bool follow); + + /** + * Pause the download (and call updateStatus) + */ void pauseDownload(); + + /** + * Resume the download (and call updateStatus) + */ void resumeDownload(); + + /** + * Cancel the download. + * + * A canceled downlod cannot be resume and updateStatus does nothing. + * However, you can still get information based on the last known information. + */ void cancelDownload(); - StatusResult getStatus() { return m_status; } - std::string getDid() { return m_did; } - std::string getFollowedBy() { return m_followedBy; } - uint64_t getTotalLength() { return m_totalLength; } - uint64_t getCompletedLength() { return m_completedLength; } - uint64_t getDownloadSpeed() { return m_downloadSpeed; } - uint64_t getVerifiedLength() { return m_verifiedLength; } - std::string getPath() { return m_path; } - std::vector& getUris() { return m_uris; } + + /* + * Get the status of the download. + */ + StatusResult getStatus() const { return m_status; } + + /* + * Get the id of the download. + */ + const std::string& getDid() const { return m_did; } + + /* + * Get the id of the "second" download. + * + * Set only if the "first" download is a metalink and is complete. + */ + const std::string& getFollowedBy() const { return m_followedBy; } + + /* + * Get the total length of the download. + */ + uint64_t getTotalLength() const { return m_totalLength; } + + /* + * Get the completed length of the download. + */ + uint64_t getCompletedLength() const { return m_completedLength; } + + /* + * Get the download speed of the download. + */ + uint64_t getDownloadSpeed() const { return m_downloadSpeed; } + + /* + * Get the verified length of the download. + */ + uint64_t getVerifiedLength() const { return m_verifiedLength; } + + /* + * Get the path (local file) of the download. + */ + const std::string& getPath() const { return m_path; } + + /* + * Get the download uris of the download. + */ + const std::vector& getUris() const { return m_uris; } protected: std::shared_ptr mp_aria; @@ -83,6 +162,9 @@ class Download { /** * A tool to download things. * + * A Downloader manages `Download` using aria2 in the background. + * `Downloader` is threadsafe. + * However, the returned `Download`s are NOT threadsafe. */ class Downloader { @@ -92,14 +174,41 @@ class Downloader void close(); - Download* startDownload(const std::string& uri, const std::vector>& options = {}); - Download* getDownload(const std::string& did); + /** + * Start a new download. + * + * This method is thread safe and return a pointer to a newly created `Download`. + * User should call `update` on the returned `Download` to have an accurate status. + * + * @param uri: The uri of the thing to download. + * @param options: A series of pair to pass to aria. + * @return: The newly created Download. + */ + std::shared_ptr startDownload(const std::string& uri, const std::vector>& options = {}); - size_t getNbDownload() { return m_knownDownloads.size(); } - std::vector getDownloadIds(); + /** + * Get a download corrsponding to a download id (did) + * User should call `update` on the returned `Download` to have an accurate status. + * + * @param did: The download id to search for. + * @return: The Download corresponding to did. + * @throw: Throw std::out_of_range if did is not found. + */ + std::shared_ptr getDownload(const std::string& did); + + /** + * Get the number of downloads currently managed. + */ + size_t getNbDownload() const; + + /** + * Get the ids of the managed downloads. + */ + std::vector getDownloadIds() const; private: - std::map> m_knownDownloads; + mutable std::mutex m_lock; + std::map> m_knownDownloads; std::shared_ptr mp_aria; }; } diff --git a/src/aria2.cpp b/src/aria2.cpp index 78f541128..e58b2bd67 100644 --- a/src/aria2.cpp +++ b/src/aria2.cpp @@ -24,7 +24,7 @@ #define LOG_ARIA_ERROR() \ { \ std::cerr << "ERROR: aria2 RPC request failed. (" << res << ")." << std::endl; \ - std::cerr << (m_curlErrorBuffer[0] ? m_curlErrorBuffer.get() : curl_easy_strerror(res)) << std::endl; \ + std::cerr << (curlErrorBuffer[0] ? curlErrorBuffer : curl_easy_strerror(res)) << std::endl; \ } namespace kiwix { @@ -32,9 +32,7 @@ namespace kiwix { Aria2::Aria2(): mp_aria(nullptr), m_port(42042), - m_secret(getNewRpcSecret()), - m_curlErrorBuffer(new char[CURL_ERROR_SIZE]), - mp_curl(nullptr) + m_secret(getNewRpcSecret()) { m_downloadDir = getDataDirectory(); makeDirectory(m_downloadDir); @@ -91,36 +89,32 @@ Aria2::Aria2(): launchCmd.append(cmd).append(" "); } mp_aria = Subprocess::run(callCmd); - mp_curl = curl_easy_init(); - curl_easy_setopt(mp_curl, CURLOPT_URL, "http://localhost/rpc"); - curl_easy_setopt(mp_curl, CURLOPT_PORT, m_port); - curl_easy_setopt(mp_curl, CURLOPT_POST, 1L); - curl_easy_setopt(mp_curl, CURLOPT_ERRORBUFFER, m_curlErrorBuffer.get()); + CURL* p_curl = curl_easy_init(); + char curlErrorBuffer[CURL_ERROR_SIZE]; + + curl_easy_setopt(p_curl, CURLOPT_URL, "http://localhost/rpc"); + curl_easy_setopt(p_curl, CURLOPT_PORT, m_port); + curl_easy_setopt(p_curl, CURLOPT_POST, 1L); + curl_easy_setopt(p_curl, CURLOPT_ERRORBUFFER, curlErrorBuffer); int watchdog = 50; while(--watchdog) { sleep(10); - m_curlErrorBuffer[0] = 0; - auto res = curl_easy_perform(mp_curl); + curlErrorBuffer[0] = 0; + auto res = curl_easy_perform(p_curl); if (res == CURLE_OK) { break; } else if (watchdog == 1) { LOG_ARIA_ERROR(); } } + curl_easy_cleanup(p_curl); if (!watchdog) { - curl_easy_cleanup(mp_curl); throw std::runtime_error("Cannot connect to aria2c rpc. Aria2c launch cmd : " + launchCmd); } } -Aria2::~Aria2() -{ - std::unique_lock lock(m_lock); - curl_easy_cleanup(mp_curl); -} - void Aria2::close() { saveSession(); @@ -140,20 +134,25 @@ std::string Aria2::doRequest(const MethodCall& methodCall) std::stringstream outStream; CURLcode res; long response_code; - { - std::unique_lock lock(m_lock); - curl_easy_setopt(mp_curl, CURLOPT_POSTFIELDSIZE, requestContent.size()); - curl_easy_setopt(mp_curl, CURLOPT_POSTFIELDS, requestContent.c_str()); - curl_easy_setopt(mp_curl, CURLOPT_WRITEFUNCTION, &write_callback_to_iss); - curl_easy_setopt(mp_curl, CURLOPT_WRITEDATA, &outStream); - m_curlErrorBuffer[0] = 0; - res = curl_easy_perform(mp_curl); - if (res != CURLE_OK) { - LOG_ARIA_ERROR(); - throw std::runtime_error("Cannot perform request"); - } - curl_easy_getinfo(mp_curl, CURLINFO_RESPONSE_CODE, &response_code); + char curlErrorBuffer[CURL_ERROR_SIZE]; + CURL* p_curl = curl_easy_init(); + curl_easy_setopt(p_curl, CURLOPT_URL, "http://localhost/rpc"); + curl_easy_setopt(p_curl, CURLOPT_PORT, m_port); + curl_easy_setopt(p_curl, CURLOPT_POST, 1L); + curl_easy_setopt(p_curl, CURLOPT_ERRORBUFFER, curlErrorBuffer); + curl_easy_setopt(p_curl, CURLOPT_POSTFIELDSIZE, requestContent.size()); + curl_easy_setopt(p_curl, CURLOPT_POSTFIELDS, requestContent.c_str()); + curl_easy_setopt(p_curl, CURLOPT_WRITEFUNCTION, &write_callback_to_iss); + curl_easy_setopt(p_curl, CURLOPT_WRITEDATA, &outStream); + curlErrorBuffer[0] = 0; + res = curl_easy_perform(p_curl); + if (res != CURLE_OK) { + LOG_ARIA_ERROR(); + curl_easy_cleanup(p_curl); + throw std::runtime_error("Cannot perform request"); } + curl_easy_getinfo(p_curl, CURLINFO_RESPONSE_CODE, &response_code); + curl_easy_cleanup(p_curl); auto responseContent = outStream.str(); if (response_code != 200) { diff --git a/src/aria2.h b/src/aria2.h index 47a3f826a..f6cd633b8 100644 --- a/src/aria2.h +++ b/src/aria2.h @@ -12,7 +12,6 @@ #include "xmlrpc.h" #include -#include #include namespace kiwix { @@ -24,15 +23,11 @@ class Aria2 int m_port; std::string m_secret; std::string m_downloadDir; - std::unique_ptr m_curlErrorBuffer; - CURL* mp_curl; - std::mutex m_lock; - std::string doRequest(const MethodCall& methodCall); public: Aria2(); - virtual ~Aria2(); + virtual ~Aria2() = default; void close(); std::string addUri(const std::vector& uri, const std::vector>& options = {}); diff --git a/src/downloader.cpp b/src/downloader.cpp index 9bffad385..d874ad899 100644 --- a/src/downloader.cpp +++ b/src/downloader.cpp @@ -127,22 +127,24 @@ void Download::cancelDownload() Downloader::Downloader() : mp_aria(new Aria2()) { - try { - for (auto gid : mp_aria->tellActive()) { - m_knownDownloads[gid] = std::unique_ptr(new Download(mp_aria, gid)); - m_knownDownloads[gid]->updateStatus(); - } - } catch (std::exception& e) { - std::cerr << "aria2 tellActive failed : " << e.what() << std::endl; - } try { for (auto gid : mp_aria->tellWaiting()) { m_knownDownloads[gid] = std::unique_ptr(new Download(mp_aria, gid)); - m_knownDownloads[gid]->updateStatus(); + m_knownDownloads[gid]->updateStatus(false); } } catch (std::exception& e) { std::cerr << "aria2 tellWaiting failed : " << e.what() << std::endl; } + try { + for (auto gid : mp_aria->tellActive()) { + if( m_knownDownloads.find(gid) == m_knownDownloads.end()) { + m_knownDownloads[gid] = std::unique_ptr(new Download(mp_aria, gid)); + m_knownDownloads[gid]->updateStatus(false); + } + } + } catch (std::exception& e) { + std::cerr << "aria2 tellActive failed : " << e.what() << std::endl; + } } /* Destructor */ @@ -155,7 +157,8 @@ void Downloader::close() mp_aria->close(); } -std::vector Downloader::getDownloadIds() { +std::vector Downloader::getDownloadIds() const { + std::unique_lock lock(m_lock); std::vector ret; for(auto& p:m_knownDownloads) { ret.push_back(p.first); @@ -163,42 +166,46 @@ std::vector Downloader::getDownloadIds() { return ret; } -Download* Downloader::startDownload(const std::string& uri, const std::vector>& options) +std::shared_ptr Downloader::startDownload(const std::string& uri, const std::vector>& options) { + std::unique_lock lock(m_lock); for (auto& p: m_knownDownloads) { auto& d = p.second; auto& uris = d->getUris(); if (std::find(uris.begin(), uris.end(), uri) != uris.end()) - return d.get(); + return d; } std::vector uris = {uri}; auto gid = mp_aria->addUri(uris, options); - m_knownDownloads[gid] = std::unique_ptr(new Download(mp_aria, gid)); - return m_knownDownloads[gid].get(); + m_knownDownloads[gid] = std::make_shared(mp_aria, gid); + return m_knownDownloads[gid]; } -Download* Downloader::getDownload(const std::string& did) +std::shared_ptr Downloader::getDownload(const std::string& did) { + std::unique_lock lock(m_lock); try { - m_knownDownloads.at(did).get()->updateStatus(true); - return m_knownDownloads.at(did).get(); + return m_knownDownloads.at(did); } catch(std::exception& e) { - for (auto gid : mp_aria->tellActive()) { - if (gid == did) { - m_knownDownloads[gid] = std::unique_ptr(new Download(mp_aria, gid)); - m_knownDownloads.at(gid).get()->updateStatus(true); - return m_knownDownloads[gid].get(); - } - } for (auto gid : mp_aria->tellWaiting()) { if (gid == did) { - m_knownDownloads[gid] = std::unique_ptr(new Download(mp_aria, gid)); - m_knownDownloads.at(gid).get()->updateStatus(true); - return m_knownDownloads[gid].get(); + m_knownDownloads[gid] = std::make_shared(mp_aria, gid); + return m_knownDownloads[gid]; } - } + } + for (auto gid : mp_aria->tellActive()) { + if (gid == did) { + m_knownDownloads[gid] = std::make_shared(mp_aria, gid); + return m_knownDownloads[gid]; + } + } throw e; } } +size_t Downloader::getNbDownload() const { + std::unique_lock lock(m_lock); + return m_knownDownloads.size(); +} + }