Merge pull request #886 from kiwix/thread_aria

This commit is contained in:
Matthieu Gautier 2023-02-08 16:21:52 +01:00 committed by GitHub
commit 2f419996ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 190 additions and 80 deletions

View File

@ -25,6 +25,7 @@
#include <map> #include <map>
#include <memory> #include <memory>
#include <stdexcept> #include <stdexcept>
#include <mutex>
namespace kiwix namespace kiwix
{ {
@ -43,6 +44,14 @@ class AriaError : public std::runtime_error {
}; };
/**
* A representation of a current download.
*
* `Download` is not thread safe. User must care to not call method on a
* same download from different threads.
* However, it is safe to use different `Download`s from different threads.
*/
class Download { class Download {
public: public:
typedef enum { K_ACTIVE, K_WAITING, K_PAUSED, K_ERROR, K_COMPLETE, K_REMOVED, K_UNKNOWN } StatusResult; typedef enum { K_ACTIVE, K_WAITING, K_PAUSED, K_ERROR, K_COMPLETE, K_REMOVED, K_UNKNOWN } StatusResult;
@ -53,19 +62,89 @@ class Download {
: mp_aria(p_aria), : mp_aria(p_aria),
m_status(K_UNKNOWN), m_status(K_UNKNOWN),
m_did(did) {}; m_did(did) {};
void updateStatus(bool follow=false);
/**
* Update the status of the download.
*
* This call make an aria rpc call and is blocking.
* Some download (started with a metalink) are in fact several downloads.
* - A first one to download the metadlink.
* - A second one to download the real file.
*
* If `follow` is true, updateStatus tries to detect that and tracks
* the second download when the first one is finished.
* By passing false to `follow`, `Download` will only track the first download.
*
* `getFoo` methods are based on the last statusUpdate.
*
* @param follow: Do we have to follow following downloads.
*/
void updateStatus(bool follow);
/**
* Pause the download (and call updateStatus)
*/
void pauseDownload(); void pauseDownload();
/**
* Resume the download (and call updateStatus)
*/
void resumeDownload(); void resumeDownload();
/**
* Cancel the download.
*
* A canceled downlod cannot be resume and updateStatus does nothing.
* However, you can still get information based on the last known information.
*/
void cancelDownload(); void cancelDownload();
StatusResult getStatus() { return m_status; }
std::string getDid() { return m_did; } /*
std::string getFollowedBy() { return m_followedBy; } * Get the status of the download.
uint64_t getTotalLength() { return m_totalLength; } */
uint64_t getCompletedLength() { return m_completedLength; } StatusResult getStatus() const { return m_status; }
uint64_t getDownloadSpeed() { return m_downloadSpeed; }
uint64_t getVerifiedLength() { return m_verifiedLength; } /*
std::string getPath() { return m_path; } * Get the id of the download.
std::vector<std::string>& getUris() { return m_uris; } */
const std::string& getDid() const { return m_did; }
/*
* Get the id of the "second" download.
*
* Set only if the "first" download is a metalink and is complete.
*/
const std::string& getFollowedBy() const { return m_followedBy; }
/*
* Get the total length of the download.
*/
uint64_t getTotalLength() const { return m_totalLength; }
/*
* Get the completed length of the download.
*/
uint64_t getCompletedLength() const { return m_completedLength; }
/*
* Get the download speed of the download.
*/
uint64_t getDownloadSpeed() const { return m_downloadSpeed; }
/*
* Get the verified length of the download.
*/
uint64_t getVerifiedLength() const { return m_verifiedLength; }
/*
* Get the path (local file) of the download.
*/
const std::string& getPath() const { return m_path; }
/*
* Get the download uris of the download.
*/
const std::vector<std::string>& getUris() const { return m_uris; }
protected: protected:
std::shared_ptr<Aria2> mp_aria; std::shared_ptr<Aria2> mp_aria;
@ -83,6 +162,9 @@ class Download {
/** /**
* A tool to download things. * A tool to download things.
* *
* A Downloader manages `Download` using aria2 in the background.
* `Downloader` is threadsafe.
* However, the returned `Download`s are NOT threadsafe.
*/ */
class Downloader class Downloader
{ {
@ -92,14 +174,41 @@ class Downloader
void close(); void close();
Download* startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options = {}); /**
Download* getDownload(const std::string& did); * Start a new download.
*
* This method is thread safe and return a pointer to a newly created `Download`.
* User should call `update` on the returned `Download` to have an accurate status.
*
* @param uri: The uri of the thing to download.
* @param options: A series of pair <option_name, option_value> to pass to aria.
* @return: The newly created Download.
*/
std::shared_ptr<Download> startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options = {});
size_t getNbDownload() { return m_knownDownloads.size(); } /**
std::vector<std::string> getDownloadIds(); * Get a download corrsponding to a download id (did)
* User should call `update` on the returned `Download` to have an accurate status.
*
* @param did: The download id to search for.
* @return: The Download corresponding to did.
* @throw: Throw std::out_of_range if did is not found.
*/
std::shared_ptr<Download> getDownload(const std::string& did);
/**
* Get the number of downloads currently managed.
*/
size_t getNbDownload() const;
/**
* Get the ids of the managed downloads.
*/
std::vector<std::string> getDownloadIds() const;
private: private:
std::map<std::string, std::unique_ptr<Download>> m_knownDownloads; mutable std::mutex m_lock;
std::map<std::string, std::shared_ptr<Download>> m_knownDownloads;
std::shared_ptr<Aria2> mp_aria; std::shared_ptr<Aria2> mp_aria;
}; };
} }

View File

@ -24,7 +24,7 @@
#define LOG_ARIA_ERROR() \ #define LOG_ARIA_ERROR() \
{ \ { \
std::cerr << "ERROR: aria2 RPC request failed. (" << res << ")." << std::endl; \ std::cerr << "ERROR: aria2 RPC request failed. (" << res << ")." << std::endl; \
std::cerr << (m_curlErrorBuffer[0] ? m_curlErrorBuffer.get() : curl_easy_strerror(res)) << std::endl; \ std::cerr << (curlErrorBuffer[0] ? curlErrorBuffer : curl_easy_strerror(res)) << std::endl; \
} }
namespace kiwix { namespace kiwix {
@ -32,9 +32,7 @@ namespace kiwix {
Aria2::Aria2(): Aria2::Aria2():
mp_aria(nullptr), mp_aria(nullptr),
m_port(42042), m_port(42042),
m_secret(getNewRpcSecret()), m_secret(getNewRpcSecret())
m_curlErrorBuffer(new char[CURL_ERROR_SIZE]),
mp_curl(nullptr)
{ {
m_downloadDir = getDataDirectory(); m_downloadDir = getDataDirectory();
makeDirectory(m_downloadDir); makeDirectory(m_downloadDir);
@ -91,36 +89,32 @@ Aria2::Aria2():
launchCmd.append(cmd).append(" "); launchCmd.append(cmd).append(" ");
} }
mp_aria = Subprocess::run(callCmd); mp_aria = Subprocess::run(callCmd);
mp_curl = curl_easy_init();
curl_easy_setopt(mp_curl, CURLOPT_URL, "http://localhost/rpc"); CURL* p_curl = curl_easy_init();
curl_easy_setopt(mp_curl, CURLOPT_PORT, m_port); char curlErrorBuffer[CURL_ERROR_SIZE];
curl_easy_setopt(mp_curl, CURLOPT_POST, 1L);
curl_easy_setopt(mp_curl, CURLOPT_ERRORBUFFER, m_curlErrorBuffer.get()); curl_easy_setopt(p_curl, CURLOPT_URL, "http://localhost/rpc");
curl_easy_setopt(p_curl, CURLOPT_PORT, m_port);
curl_easy_setopt(p_curl, CURLOPT_POST, 1L);
curl_easy_setopt(p_curl, CURLOPT_ERRORBUFFER, curlErrorBuffer);
int watchdog = 50; int watchdog = 50;
while(--watchdog) { while(--watchdog) {
sleep(10); sleep(10);
m_curlErrorBuffer[0] = 0; curlErrorBuffer[0] = 0;
auto res = curl_easy_perform(mp_curl); auto res = curl_easy_perform(p_curl);
if (res == CURLE_OK) { if (res == CURLE_OK) {
break; break;
} else if (watchdog == 1) { } else if (watchdog == 1) {
LOG_ARIA_ERROR(); LOG_ARIA_ERROR();
} }
} }
curl_easy_cleanup(p_curl);
if (!watchdog) { if (!watchdog) {
curl_easy_cleanup(mp_curl);
throw std::runtime_error("Cannot connect to aria2c rpc. Aria2c launch cmd : " + launchCmd); throw std::runtime_error("Cannot connect to aria2c rpc. Aria2c launch cmd : " + launchCmd);
} }
} }
Aria2::~Aria2()
{
std::unique_lock<std::mutex> lock(m_lock);
curl_easy_cleanup(mp_curl);
}
void Aria2::close() void Aria2::close()
{ {
saveSession(); saveSession();
@ -140,20 +134,25 @@ std::string Aria2::doRequest(const MethodCall& methodCall)
std::stringstream outStream; std::stringstream outStream;
CURLcode res; CURLcode res;
long response_code; long response_code;
{ char curlErrorBuffer[CURL_ERROR_SIZE];
std::unique_lock<std::mutex> lock(m_lock); CURL* p_curl = curl_easy_init();
curl_easy_setopt(mp_curl, CURLOPT_POSTFIELDSIZE, requestContent.size()); curl_easy_setopt(p_curl, CURLOPT_URL, "http://localhost/rpc");
curl_easy_setopt(mp_curl, CURLOPT_POSTFIELDS, requestContent.c_str()); curl_easy_setopt(p_curl, CURLOPT_PORT, m_port);
curl_easy_setopt(mp_curl, CURLOPT_WRITEFUNCTION, &write_callback_to_iss); curl_easy_setopt(p_curl, CURLOPT_POST, 1L);
curl_easy_setopt(mp_curl, CURLOPT_WRITEDATA, &outStream); curl_easy_setopt(p_curl, CURLOPT_ERRORBUFFER, curlErrorBuffer);
m_curlErrorBuffer[0] = 0; curl_easy_setopt(p_curl, CURLOPT_POSTFIELDSIZE, requestContent.size());
res = curl_easy_perform(mp_curl); curl_easy_setopt(p_curl, CURLOPT_POSTFIELDS, requestContent.c_str());
curl_easy_setopt(p_curl, CURLOPT_WRITEFUNCTION, &write_callback_to_iss);
curl_easy_setopt(p_curl, CURLOPT_WRITEDATA, &outStream);
curlErrorBuffer[0] = 0;
res = curl_easy_perform(p_curl);
if (res != CURLE_OK) { if (res != CURLE_OK) {
LOG_ARIA_ERROR(); LOG_ARIA_ERROR();
curl_easy_cleanup(p_curl);
throw std::runtime_error("Cannot perform request"); throw std::runtime_error("Cannot perform request");
} }
curl_easy_getinfo(mp_curl, CURLINFO_RESPONSE_CODE, &response_code); curl_easy_getinfo(p_curl, CURLINFO_RESPONSE_CODE, &response_code);
} curl_easy_cleanup(p_curl);
auto responseContent = outStream.str(); auto responseContent = outStream.str();
if (response_code != 200) { if (response_code != 200) {

View File

@ -12,7 +12,6 @@
#include "xmlrpc.h" #include "xmlrpc.h"
#include <memory> #include <memory>
#include <mutex>
#include <curl/curl.h> #include <curl/curl.h>
namespace kiwix { namespace kiwix {
@ -24,15 +23,11 @@ class Aria2
int m_port; int m_port;
std::string m_secret; std::string m_secret;
std::string m_downloadDir; std::string m_downloadDir;
std::unique_ptr<char[]> m_curlErrorBuffer;
CURL* mp_curl;
std::mutex m_lock;
std::string doRequest(const MethodCall& methodCall); std::string doRequest(const MethodCall& methodCall);
public: public:
Aria2(); Aria2();
virtual ~Aria2(); virtual ~Aria2() = default;
void close(); void close();
std::string addUri(const std::vector<std::string>& uri, const std::vector<std::pair<std::string, std::string>>& options = {}); std::string addUri(const std::vector<std::string>& uri, const std::vector<std::pair<std::string, std::string>>& options = {});

View File

@ -127,22 +127,24 @@ void Download::cancelDownload()
Downloader::Downloader() : Downloader::Downloader() :
mp_aria(new Aria2()) mp_aria(new Aria2())
{ {
try {
for (auto gid : mp_aria->tellActive()) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
m_knownDownloads[gid]->updateStatus();
}
} catch (std::exception& e) {
std::cerr << "aria2 tellActive failed : " << e.what() << std::endl;
}
try { try {
for (auto gid : mp_aria->tellWaiting()) { for (auto gid : mp_aria->tellWaiting()) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid)); m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
m_knownDownloads[gid]->updateStatus(); m_knownDownloads[gid]->updateStatus(false);
} }
} catch (std::exception& e) { } catch (std::exception& e) {
std::cerr << "aria2 tellWaiting failed : " << e.what() << std::endl; std::cerr << "aria2 tellWaiting failed : " << e.what() << std::endl;
} }
try {
for (auto gid : mp_aria->tellActive()) {
if( m_knownDownloads.find(gid) == m_knownDownloads.end()) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
m_knownDownloads[gid]->updateStatus(false);
}
}
} catch (std::exception& e) {
std::cerr << "aria2 tellActive failed : " << e.what() << std::endl;
}
} }
/* Destructor */ /* Destructor */
@ -155,7 +157,8 @@ void Downloader::close()
mp_aria->close(); mp_aria->close();
} }
std::vector<std::string> Downloader::getDownloadIds() { std::vector<std::string> Downloader::getDownloadIds() const {
std::unique_lock<std::mutex> lock(m_lock);
std::vector<std::string> ret; std::vector<std::string> ret;
for(auto& p:m_knownDownloads) { for(auto& p:m_knownDownloads) {
ret.push_back(p.first); ret.push_back(p.first);
@ -163,42 +166,46 @@ std::vector<std::string> Downloader::getDownloadIds() {
return ret; return ret;
} }
Download* Downloader::startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options) std::shared_ptr<Download> Downloader::startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options)
{ {
std::unique_lock<std::mutex> lock(m_lock);
for (auto& p: m_knownDownloads) { for (auto& p: m_knownDownloads) {
auto& d = p.second; auto& d = p.second;
auto& uris = d->getUris(); auto& uris = d->getUris();
if (std::find(uris.begin(), uris.end(), uri) != uris.end()) if (std::find(uris.begin(), uris.end(), uri) != uris.end())
return d.get(); return d;
} }
std::vector<std::string> uris = {uri}; std::vector<std::string> uris = {uri};
auto gid = mp_aria->addUri(uris, options); auto gid = mp_aria->addUri(uris, options);
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid)); m_knownDownloads[gid] = std::make_shared<Download>(mp_aria, gid);
return m_knownDownloads[gid].get(); return m_knownDownloads[gid];
} }
Download* Downloader::getDownload(const std::string& did) std::shared_ptr<Download> Downloader::getDownload(const std::string& did)
{ {
std::unique_lock<std::mutex> lock(m_lock);
try { try {
m_knownDownloads.at(did).get()->updateStatus(true); return m_knownDownloads.at(did);
return m_knownDownloads.at(did).get();
} catch(std::exception& e) { } catch(std::exception& e) {
for (auto gid : mp_aria->tellActive()) {
if (gid == did) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
m_knownDownloads.at(gid).get()->updateStatus(true);
return m_knownDownloads[gid].get();
}
}
for (auto gid : mp_aria->tellWaiting()) { for (auto gid : mp_aria->tellWaiting()) {
if (gid == did) { if (gid == did) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid)); m_knownDownloads[gid] = std::make_shared<Download>(mp_aria, gid);
m_knownDownloads.at(gid).get()->updateStatus(true); return m_knownDownloads[gid];
return m_knownDownloads[gid].get(); }
}
for (auto gid : mp_aria->tellActive()) {
if (gid == did) {
m_knownDownloads[gid] = std::make_shared<Download>(mp_aria, gid);
return m_knownDownloads[gid];
} }
} }
throw e; throw e;
} }
} }
size_t Downloader::getNbDownload() const {
std::unique_lock<std::mutex> lock(m_lock);
return m_knownDownloads.size();
}
} }