Merge pull request #886 from kiwix/thread_aria

This commit is contained in:
Matthieu Gautier 2023-02-08 16:21:52 +01:00 committed by GitHub
commit 2f419996ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 190 additions and 80 deletions

View File

@ -25,6 +25,7 @@
#include <map>
#include <memory>
#include <stdexcept>
#include <mutex>
namespace kiwix
{
@ -43,6 +44,14 @@ class AriaError : public std::runtime_error {
};
/**
* A representation of a current download.
*
* `Download` is not thread safe. Users must take care not to call methods on
* the same download from different threads.
* However, it is safe to use different `Download`s from different threads.
*/
class Download {
public:
typedef enum { K_ACTIVE, K_WAITING, K_PAUSED, K_ERROR, K_COMPLETE, K_REMOVED, K_UNKNOWN } StatusResult;
@ -53,19 +62,89 @@ class Download {
: mp_aria(p_aria),
m_status(K_UNKNOWN),
m_did(did) {};
void updateStatus(bool follow=false);
/**
* Update the status of the download.
*
* This call makes an aria RPC call and is blocking.
* Some downloads (started with a metalink) are in fact several downloads.
* - A first one to download the metalink.
* - A second one to download the real file.
*
* If `follow` is true, updateStatus tries to detect that and tracks
* the second download when the first one is finished.
* By passing false to `follow`, `Download` will only track the first download.
*
* `getFoo` methods are based on the last call to `updateStatus`.
*
* @param follow: Whether to follow the subsequent download.
*/
void updateStatus(bool follow);
/**
* Pause the download (and call updateStatus)
*/
void pauseDownload();
/**
* Resume the download (and call updateStatus)
*/
void resumeDownload();
/**
* Cancel the download.
*
* A canceled download cannot be resumed and updateStatus does nothing.
* However, you can still get information based on the last known information.
*/
void cancelDownload();
StatusResult getStatus() { return m_status; }
std::string getDid() { return m_did; }
std::string getFollowedBy() { return m_followedBy; }
uint64_t getTotalLength() { return m_totalLength; }
uint64_t getCompletedLength() { return m_completedLength; }
uint64_t getDownloadSpeed() { return m_downloadSpeed; }
uint64_t getVerifiedLength() { return m_verifiedLength; }
std::string getPath() { return m_path; }
std::vector<std::string>& getUris() { return m_uris; }
/*
* Get the status of the download.
*/
StatusResult getStatus() const { return m_status; }
/*
* Get the id of the download.
*/
const std::string& getDid() const { return m_did; }
/*
* Get the id of the "second" download.
*
* Set only if the "first" download is a metalink and is complete.
*/
const std::string& getFollowedBy() const { return m_followedBy; }
/*
* Get the total length of the download.
*/
uint64_t getTotalLength() const { return m_totalLength; }
/*
* Get the completed length of the download.
*/
uint64_t getCompletedLength() const { return m_completedLength; }
/*
* Get the download speed of the download.
*/
uint64_t getDownloadSpeed() const { return m_downloadSpeed; }
/*
* Get the verified length of the download.
*/
uint64_t getVerifiedLength() const { return m_verifiedLength; }
/*
* Get the path (local file) of the download.
*/
const std::string& getPath() const { return m_path; }
/*
* Get the download uris of the download.
*/
const std::vector<std::string>& getUris() const { return m_uris; }
protected:
std::shared_ptr<Aria2> mp_aria;
@ -83,6 +162,9 @@ class Download {
/**
* A tool to download things.
*
* A Downloader manages `Download` using aria2 in the background.
* `Downloader` is threadsafe.
* However, the returned `Download`s are NOT threadsafe.
*/
class Downloader
{
@ -92,14 +174,41 @@ class Downloader
void close();
Download* startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options = {});
Download* getDownload(const std::string& did);
/**
* Start a new download.
*
* This method is thread safe and returns a pointer to a newly created `Download`.
* User should call `updateStatus` on the returned `Download` to have an accurate status.
*
* @param uri: The uri of the thing to download.
* @param options: A series of pair <option_name, option_value> to pass to aria.
* @return: The newly created Download.
*/
std::shared_ptr<Download> startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options = {});
size_t getNbDownload() { return m_knownDownloads.size(); }
std::vector<std::string> getDownloadIds();
/**
* Get a download corresponding to a download id (did).
* User should call `updateStatus` on the returned `Download` to have an accurate status.
*
* @param did: The download id to search for.
* @return: The Download corresponding to did.
* @throw: Throw std::out_of_range if did is not found.
*/
std::shared_ptr<Download> getDownload(const std::string& did);
/**
* Get the number of downloads currently managed.
*/
size_t getNbDownload() const;
/**
* Get the ids of the managed downloads.
*/
std::vector<std::string> getDownloadIds() const;
private:
std::map<std::string, std::unique_ptr<Download>> m_knownDownloads;
mutable std::mutex m_lock;
std::map<std::string, std::shared_ptr<Download>> m_knownDownloads;
std::shared_ptr<Aria2> mp_aria;
};
}

View File

@ -24,7 +24,7 @@
#define LOG_ARIA_ERROR() \
{ \
std::cerr << "ERROR: aria2 RPC request failed. (" << res << ")." << std::endl; \
std::cerr << (m_curlErrorBuffer[0] ? m_curlErrorBuffer.get() : curl_easy_strerror(res)) << std::endl; \
std::cerr << (curlErrorBuffer[0] ? curlErrorBuffer : curl_easy_strerror(res)) << std::endl; \
}
namespace kiwix {
@ -32,9 +32,7 @@ namespace kiwix {
Aria2::Aria2():
mp_aria(nullptr),
m_port(42042),
m_secret(getNewRpcSecret()),
m_curlErrorBuffer(new char[CURL_ERROR_SIZE]),
mp_curl(nullptr)
m_secret(getNewRpcSecret())
{
m_downloadDir = getDataDirectory();
makeDirectory(m_downloadDir);
@ -91,36 +89,32 @@ Aria2::Aria2():
launchCmd.append(cmd).append(" ");
}
mp_aria = Subprocess::run(callCmd);
mp_curl = curl_easy_init();
curl_easy_setopt(mp_curl, CURLOPT_URL, "http://localhost/rpc");
curl_easy_setopt(mp_curl, CURLOPT_PORT, m_port);
curl_easy_setopt(mp_curl, CURLOPT_POST, 1L);
curl_easy_setopt(mp_curl, CURLOPT_ERRORBUFFER, m_curlErrorBuffer.get());
CURL* p_curl = curl_easy_init();
char curlErrorBuffer[CURL_ERROR_SIZE];
curl_easy_setopt(p_curl, CURLOPT_URL, "http://localhost/rpc");
curl_easy_setopt(p_curl, CURLOPT_PORT, m_port);
curl_easy_setopt(p_curl, CURLOPT_POST, 1L);
curl_easy_setopt(p_curl, CURLOPT_ERRORBUFFER, curlErrorBuffer);
int watchdog = 50;
while(--watchdog) {
sleep(10);
m_curlErrorBuffer[0] = 0;
auto res = curl_easy_perform(mp_curl);
curlErrorBuffer[0] = 0;
auto res = curl_easy_perform(p_curl);
if (res == CURLE_OK) {
break;
} else if (watchdog == 1) {
LOG_ARIA_ERROR();
}
}
curl_easy_cleanup(p_curl);
if (!watchdog) {
curl_easy_cleanup(mp_curl);
throw std::runtime_error("Cannot connect to aria2c rpc. Aria2c launch cmd : " + launchCmd);
}
}
/**
 * Destructor: releases the curl handle held in mp_curl.
 *
 * NOTE(review): the header hunk shown in this listing declares the destructor
 * as `= default` and no longer lists mp_curl, so this definition looks like
 * the removed side of the diff — confirm against the real source file.
 */
Aria2::~Aria2()
{
// Take m_lock before touching the handle.
std::unique_lock<std::mutex> lock(m_lock);
curl_easy_cleanup(mp_curl);
}
void Aria2::close()
{
saveSession();
@ -140,20 +134,25 @@ std::string Aria2::doRequest(const MethodCall& methodCall)
std::stringstream outStream;
CURLcode res;
long response_code;
{
std::unique_lock<std::mutex> lock(m_lock);
curl_easy_setopt(mp_curl, CURLOPT_POSTFIELDSIZE, requestContent.size());
curl_easy_setopt(mp_curl, CURLOPT_POSTFIELDS, requestContent.c_str());
curl_easy_setopt(mp_curl, CURLOPT_WRITEFUNCTION, &write_callback_to_iss);
curl_easy_setopt(mp_curl, CURLOPT_WRITEDATA, &outStream);
m_curlErrorBuffer[0] = 0;
res = curl_easy_perform(mp_curl);
if (res != CURLE_OK) {
LOG_ARIA_ERROR();
throw std::runtime_error("Cannot perform request");
}
curl_easy_getinfo(mp_curl, CURLINFO_RESPONSE_CODE, &response_code);
char curlErrorBuffer[CURL_ERROR_SIZE];
CURL* p_curl = curl_easy_init();
curl_easy_setopt(p_curl, CURLOPT_URL, "http://localhost/rpc");
curl_easy_setopt(p_curl, CURLOPT_PORT, m_port);
curl_easy_setopt(p_curl, CURLOPT_POST, 1L);
curl_easy_setopt(p_curl, CURLOPT_ERRORBUFFER, curlErrorBuffer);
curl_easy_setopt(p_curl, CURLOPT_POSTFIELDSIZE, requestContent.size());
curl_easy_setopt(p_curl, CURLOPT_POSTFIELDS, requestContent.c_str());
curl_easy_setopt(p_curl, CURLOPT_WRITEFUNCTION, &write_callback_to_iss);
curl_easy_setopt(p_curl, CURLOPT_WRITEDATA, &outStream);
curlErrorBuffer[0] = 0;
res = curl_easy_perform(p_curl);
if (res != CURLE_OK) {
LOG_ARIA_ERROR();
curl_easy_cleanup(p_curl);
throw std::runtime_error("Cannot perform request");
}
curl_easy_getinfo(p_curl, CURLINFO_RESPONSE_CODE, &response_code);
curl_easy_cleanup(p_curl);
auto responseContent = outStream.str();
if (response_code != 200) {

View File

@ -12,7 +12,6 @@
#include "xmlrpc.h"
#include <memory>
#include <mutex>
#include <curl/curl.h>
namespace kiwix {
@ -24,15 +23,11 @@ class Aria2
int m_port;
std::string m_secret;
std::string m_downloadDir;
std::unique_ptr<char[]> m_curlErrorBuffer;
CURL* mp_curl;
std::mutex m_lock;
std::string doRequest(const MethodCall& methodCall);
public:
Aria2();
virtual ~Aria2();
virtual ~Aria2() = default;
void close();
std::string addUri(const std::vector<std::string>& uri, const std::vector<std::pair<std::string, std::string>>& options = {});

View File

@ -127,22 +127,24 @@ void Download::cancelDownload()
/**
 * Build a Downloader backed by a freshly created Aria2 instance.
 *
 * Every gid reported by mp_aria->tellActive() / mp_aria->tellWaiting() is
 * registered in m_knownDownloads and its status is refreshed. An exception
 * raised while doing so is caught and reported on std::cerr; construction
 * then carries on.
 *
 * NOTE(review): this listing appears to interleave the removed and the added
 * lines of the commit (e.g. both `updateStatus()` and `updateStatus(false)`
 * are present) — confirm against the real source file.
 */
Downloader::Downloader() :
mp_aria(new Aria2())
{
try {
for (auto gid : mp_aria->tellActive()) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
m_knownDownloads[gid]->updateStatus();
}
} catch (std::exception& e) {
std::cerr << "aria2 tellActive failed : " << e.what() << std::endl;
}
try {
for (auto gid : mp_aria->tellWaiting()) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
m_knownDownloads[gid]->updateStatus();
m_knownDownloads[gid]->updateStatus(false);
}
} catch (std::exception& e) {
std::cerr << "aria2 tellWaiting failed : " << e.what() << std::endl;
}
try {
for (auto gid : mp_aria->tellActive()) {
// Only register gids that are not already known.
if( m_knownDownloads.find(gid) == m_knownDownloads.end()) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
m_knownDownloads[gid]->updateStatus(false);
}
}
} catch (std::exception& e) {
std::cerr << "aria2 tellActive failed : " << e.what() << std::endl;
}
}
/* Destructor */
@ -155,7 +157,8 @@ void Downloader::close()
mp_aria->close();
}
std::vector<std::string> Downloader::getDownloadIds() {
std::vector<std::string> Downloader::getDownloadIds() const {
std::unique_lock<std::mutex> lock(m_lock);
std::vector<std::string> ret;
for(auto& p:m_knownDownloads) {
ret.push_back(p.first);
@ -163,42 +166,46 @@ std::vector<std::string> Downloader::getDownloadIds() {
return ret;
}
/**
 * Start downloading `uri`, or hand back the download that already has it.
 *
 * The known downloads are scanned first: if one of them lists `uri` among
 * its uris, that download is returned. Otherwise the uri is passed to
 * mp_aria->addUri() together with `options`, and a Download for the returned
 * gid is stored in m_knownDownloads and returned.
 *
 * NOTE(review): this listing appears to interleave the removed and the added
 * lines of the commit (two signatures, raw-pointer and shared_ptr returns) —
 * confirm against the real source file.
 */
Download* Downloader::startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options)
std::shared_ptr<Download> Downloader::startDownload(const std::string& uri, const std::vector<std::pair<std::string, std::string>>& options)
{
// m_lock is held for the whole call, including the addUri() request.
std::unique_lock<std::mutex> lock(m_lock);
for (auto& p: m_knownDownloads) {
auto& d = p.second;
auto& uris = d->getUris();
if (std::find(uris.begin(), uris.end(), uri) != uris.end())
return d.get();
return d;
}
std::vector<std::string> uris = {uri};
auto gid = mp_aria->addUri(uris, options);
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
return m_knownDownloads[gid].get();
m_knownDownloads[gid] = std::make_shared<Download>(mp_aria, gid);
return m_knownDownloads[gid];
}
/**
 * Look up the download registered under `did`.
 *
 * The id is searched in m_knownDownloads first. When that lookup throws, the
 * gids reported by aria2 (tellWaiting / tellActive) are scanned; a gid equal
 * to `did` is registered in m_knownDownloads and returned. If nothing
 * matches, the caught exception is thrown again.
 *
 * NOTE(review): this listing appears to interleave the removed and the added
 * lines of the commit, so several statements occur in both their old and new
 * form — confirm against the real source file.
 */
Download* Downloader::getDownload(const std::string& did)
std::shared_ptr<Download> Downloader::getDownload(const std::string& did)
{
std::unique_lock<std::mutex> lock(m_lock);
try {
m_knownDownloads.at(did).get()->updateStatus(true);
return m_knownDownloads.at(did).get();
return m_knownDownloads.at(did);
} catch(std::exception& e) {
// Not known locally: fall back to what aria2 itself reports.
for (auto gid : mp_aria->tellActive()) {
if (gid == did) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
m_knownDownloads.at(gid).get()->updateStatus(true);
return m_knownDownloads[gid].get();
}
}
for (auto gid : mp_aria->tellWaiting()) {
if (gid == did) {
m_knownDownloads[gid] = std::unique_ptr<Download>(new Download(mp_aria, gid));
m_knownDownloads.at(gid).get()->updateStatus(true);
return m_knownDownloads[gid].get();
m_knownDownloads[gid] = std::make_shared<Download>(mp_aria, gid);
return m_knownDownloads[gid];
}
}
}
for (auto gid : mp_aria->tellActive()) {
if (gid == did) {
m_knownDownloads[gid] = std::make_shared<Download>(mp_aria, gid);
return m_knownDownloads[gid];
}
}
// No match anywhere: propagate the original lookup failure.
throw e;
}
}
size_t Downloader::getNbDownload() const {
std::unique_lock<std::mutex> lock(m_lock);
return m_knownDownloads.size();
}
}