Files
aria2/src/DHTMessageFactoryImpl.cc
Tatsuhiro Tsujikawa 3e00be26e8 Include local IP address to DHT GetPeers reply message
This is required when DHT is used in private network, where seeder is
initially not known to all nodes.  This functionality requires
--bt-external-ip is set since aria2 cannot figure out external IP
address to advertise in general.
2016-05-27 23:28:53 +09:00

541 lines
17 KiB
C++

/* <!-- copyright */
/*
* aria2 - The high speed download utility
*
* Copyright (C) 2006 Tatsuhiro Tsujikawa
*
* This program is free software; you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation; either version 2 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program; if not, write to the Free Software
* Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA
*
* In addition, as a special exception, the copyright holders give
* permission to link the code of portions of this program with the
* OpenSSL library under certain conditions as described in each
* individual source file, and distribute linked combinations
* including the two.
* You must obey the GNU General Public License in all respects
* for all of the code used other than OpenSSL. If you modify
* file(s) with this exception, you may extend this exception to your
* version of the file(s), but you are not obligated to do so. If you
* do not wish to do so, delete this exception statement from your
* version. If you delete this exception statement from all source
* files in the program, then also delete it here.
*/
/* copyright --> */
#include "DHTMessageFactoryImpl.h"
#include <cstring>
#include <utility>
#include "LogFactory.h"
#include "DlAbortEx.h"
#include "DHTNode.h"
#include "DHTRoutingTable.h"
#include "DHTPingMessage.h"
#include "DHTPingReplyMessage.h"
#include "DHTFindNodeMessage.h"
#include "DHTFindNodeReplyMessage.h"
#include "DHTGetPeersMessage.h"
#include "DHTGetPeersReplyMessage.h"
#include "DHTAnnouncePeerMessage.h"
#include "DHTAnnouncePeerReplyMessage.h"
#include "DHTUnknownMessage.h"
#include "DHTConnection.h"
#include "DHTMessageDispatcher.h"
#include "DHTPeerAnnounceStorage.h"
#include "DHTTokenTracker.h"
#include "DHTMessageCallback.h"
#include "bittorrent_helper.h"
#include "BtRuntime.h"
#include "util.h"
#include "Peer.h"
#include "Logger.h"
#include "fmt.h"
namespace aria2 {
DHTMessageFactoryImpl::DHTMessageFactoryImpl(int family)
: family_{family},
connection_{nullptr},
dispatcher_{nullptr},
routingTable_{nullptr},
peerAnnounceStorage_{nullptr},
tokenTracker_{nullptr},
btRegistry_{nullptr}
{
}
std::shared_ptr<DHTNode> DHTMessageFactoryImpl::getRemoteNode(
const unsigned char* id, const std::string& ipaddr, uint16_t port) const
{
auto node = routingTable_->getNode(id, ipaddr, port);
if (!node) {
node = std::make_shared<DHTNode>(id);
node->setIPAddress(ipaddr);
node->setPort(port);
}
return node;
}
namespace {
const Dict* getDictionary(const Dict* dict, const std::string& key)
{
const Dict* d = downcast<Dict>(dict->get(key));
if (d) {
return d;
}
else {
throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
}
}
} // namespace
namespace {
const String* getString(const Dict* dict, const std::string& key)
{
const String* c = downcast<String>(dict->get(key));
if (c) {
return c;
}
else {
throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
}
}
} // namespace
namespace {
const Integer* getInteger(const Dict* dict, const std::string& key)
{
const Integer* c = downcast<Integer>(dict->get(key));
if (c) {
return c;
}
else {
throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
}
}
} // namespace
namespace {
const String* getString(const List* list, size_t index)
{
const String* c = downcast<String>(list->get(index));
if (c) {
return c;
}
else {
throw DL_ABORT_EX(fmt("Malformed DHT message. element[%lu] is not String.",
static_cast<unsigned long>(index)));
}
}
} // namespace
namespace {
const Integer* getInteger(const List* list, size_t index)
{
const Integer* c = downcast<Integer>(list->get(index));
if (c) {
return c;
}
else {
throw DL_ABORT_EX(fmt("Malformed DHT message. element[%lu] is not Integer.",
static_cast<unsigned long>(index)));
}
}
} // namespace
namespace {
const List* getList(const Dict* dict, const std::string& key)
{
const List* l = downcast<List>(dict->get(key));
if (l) {
return l;
}
else {
throw DL_ABORT_EX(fmt("Malformed DHT message. Missing %s", key.c_str()));
}
}
} // namespace
void DHTMessageFactoryImpl::validateID(const String* id) const
{
if (id->s().size() != DHT_ID_LENGTH) {
throw DL_ABORT_EX(fmt("Malformed DHT message. Invalid ID length."
" Expected:%lu, Actual:%lu",
static_cast<unsigned long>(DHT_ID_LENGTH),
static_cast<unsigned long>(id->s().size())));
}
}
void DHTMessageFactoryImpl::validatePort(const Integer* port) const
{
if (!(0 < port->i() && port->i() < UINT16_MAX)) {
throw DL_ABORT_EX(
fmt("Malformed DHT message. Invalid port=%" PRId64 "", port->i()));
}
}
namespace {
void setVersion(DHTMessage* msg, const Dict* dict)
{
const String* v = downcast<String>(dict->get(DHTMessage::V));
if (v) {
msg->setVersion(v->s());
}
else {
msg->setVersion(A2STR::NIL);
}
}
} // namespace
std::unique_ptr<DHTQueryMessage> DHTMessageFactoryImpl::createQueryMessage(
const Dict* dict, const std::string& ipaddr, uint16_t port)
{
const String* messageType = getString(dict, DHTQueryMessage::Q);
const String* transactionID = getString(dict, DHTMessage::T);
const String* y = getString(dict, DHTMessage::Y);
const Dict* aDict = getDictionary(dict, DHTQueryMessage::A);
if (y->s() != DHTQueryMessage::Q) {
throw DL_ABORT_EX("Malformed DHT message. y != q");
}
const String* id = getString(aDict, DHTMessage::ID);
validateID(id);
auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
auto msg = std::unique_ptr<DHTQueryMessage>{};
if (messageType->s() == DHTPingMessage::PING) {
msg = createPingMessage(remoteNode, transactionID->s());
}
else if (messageType->s() == DHTFindNodeMessage::FIND_NODE) {
const String* targetNodeID =
getString(aDict, DHTFindNodeMessage::TARGET_NODE);
validateID(targetNodeID);
msg = createFindNodeMessage(remoteNode, targetNodeID->uc(),
transactionID->s());
}
else if (messageType->s() == DHTGetPeersMessage::GET_PEERS) {
const String* infoHash = getString(aDict, DHTGetPeersMessage::INFO_HASH);
validateID(infoHash);
msg = createGetPeersMessage(remoteNode, infoHash->uc(), transactionID->s());
}
else if (messageType->s() == DHTAnnouncePeerMessage::ANNOUNCE_PEER) {
const String* infoHash =
getString(aDict, DHTAnnouncePeerMessage::INFO_HASH);
validateID(infoHash);
const Integer* port = getInteger(aDict, DHTAnnouncePeerMessage::PORT);
validatePort(port);
const String* token = getString(aDict, DHTAnnouncePeerMessage::TOKEN);
msg = createAnnouncePeerMessage(remoteNode, infoHash->uc(),
static_cast<uint16_t>(port->i()),
token->s(), transactionID->s());
}
else {
throw DL_ABORT_EX(
fmt("Unsupported message type: %s", messageType->s().c_str()));
}
setVersion(msg.get(), dict);
return msg;
}
std::unique_ptr<DHTResponseMessage>
DHTMessageFactoryImpl::createResponseMessage(const std::string& messageType,
const Dict* dict,
const std::string& ipaddr,
uint16_t port)
{
const String* transactionID = getString(dict, DHTMessage::T);
const String* y = getString(dict, DHTMessage::Y);
if (y->s() == DHTUnknownMessage::E) {
// for now, just report error message arrived and throw exception.
const List* e = getList(dict, DHTUnknownMessage::E);
if (e->size() == 2) {
A2_LOG_INFO(fmt("Received Error DHT message. code=%" PRId64 ", msg=%s",
getInteger(e, 0)->i(),
util::percentEncode(getString(e, 1)->s()).c_str()));
}
else {
A2_LOG_DEBUG("e doesn't have 2 elements.");
}
throw DL_ABORT_EX("Received Error DHT message.");
}
else if (y->s() != DHTResponseMessage::R) {
throw DL_ABORT_EX(fmt("Malformed DHT message. y != r: y=%s",
util::percentEncode(y->s()).c_str()));
}
const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
const String* id = getString(rDict, DHTMessage::ID);
validateID(id);
auto remoteNode = getRemoteNode(id->uc(), ipaddr, port);
auto msg = std::unique_ptr<DHTResponseMessage>{};
if (messageType == DHTPingReplyMessage::PING) {
msg = createPingReplyMessage(remoteNode, id->uc(), transactionID->s());
}
else if (messageType == DHTFindNodeReplyMessage::FIND_NODE) {
msg = createFindNodeReplyMessage(remoteNode, dict, transactionID->s());
}
else if (messageType == DHTGetPeersReplyMessage::GET_PEERS) {
msg = createGetPeersReplyMessage(remoteNode, dict, transactionID->s());
}
else if (messageType == DHTAnnouncePeerReplyMessage::ANNOUNCE_PEER) {
msg = createAnnouncePeerReplyMessage(remoteNode, transactionID->s());
}
else {
throw DL_ABORT_EX(fmt("Unsupported message type: %s", messageType.c_str()));
}
setVersion(msg.get(), dict);
return msg;
}
namespace {
const std::string& getDefaultVersion()
{
static std::string version;
if (version.empty()) {
uint16_t vnum16 = htons(DHT_VERSION);
unsigned char buf[] = {'A', '2', 0, 0};
char* vnump = reinterpret_cast<char*>(&vnum16);
memcpy(buf + 2, vnump, 2);
version.assign(&buf[0], &buf[4]);
}
return version;
}
} // namespace
void DHTMessageFactoryImpl::setCommonProperty(DHTAbstractMessage* m)
{
m->setConnection(connection_);
m->setMessageDispatcher(dispatcher_);
m->setRoutingTable(routingTable_);
m->setMessageFactory(this);
m->setVersion(getDefaultVersion());
}
std::unique_ptr<DHTPingMessage> DHTMessageFactoryImpl::createPingMessage(
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID)
{
auto m = make_unique<DHTPingMessage>(localNode_, remoteNode, transactionID);
setCommonProperty(m.get());
return m;
}
std::unique_ptr<DHTPingReplyMessage>
DHTMessageFactoryImpl::createPingReplyMessage(
const std::shared_ptr<DHTNode>& remoteNode, const unsigned char* id,
const std::string& transactionID)
{
auto m = make_unique<DHTPingReplyMessage>(localNode_, remoteNode, id,
transactionID);
setCommonProperty(m.get());
return m;
}
std::unique_ptr<DHTFindNodeMessage>
DHTMessageFactoryImpl::createFindNodeMessage(
const std::shared_ptr<DHTNode>& remoteNode,
const unsigned char* targetNodeID, const std::string& transactionID)
{
auto m = make_unique<DHTFindNodeMessage>(localNode_, remoteNode, targetNodeID,
transactionID);
setCommonProperty(m.get());
return m;
}
std::unique_ptr<DHTFindNodeReplyMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage(
const std::shared_ptr<DHTNode>& remoteNode,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
const std::string& transactionID)
{
auto m = make_unique<DHTFindNodeReplyMessage>(family_, localNode_, remoteNode,
transactionID);
m->setClosestKNodes(std::move(closestKNodes));
setCommonProperty(m.get());
return m;
}
void DHTMessageFactoryImpl::extractNodes(
std::vector<std::shared_ptr<DHTNode>>& nodes, const unsigned char* src,
size_t length)
{
int unit = bittorrent::getCompactLength(family_) + 20;
if (length % unit != 0) {
throw DL_ABORT_EX(fmt("Nodes length is not multiple of %d", unit));
}
for (size_t offset = 0; offset < length; offset += unit) {
auto node = std::make_shared<DHTNode>(src + offset);
auto addr =
bittorrent::unpackcompact(src + offset + DHT_ID_LENGTH, family_);
if (addr.first.empty()) {
continue;
}
node->setIPAddress(addr.first);
node->setPort(addr.second);
nodes.push_back(std::move(node));
}
}
std::unique_ptr<DHTFindNodeReplyMessage>
DHTMessageFactoryImpl::createFindNodeReplyMessage(
const std::shared_ptr<DHTNode>& remoteNode, const Dict* dict,
const std::string& transactionID)
{
const String* nodesData = downcast<String>(
getDictionary(dict, DHTResponseMessage::R)
->get(family_ == AF_INET ? DHTFindNodeReplyMessage::NODES
: DHTFindNodeReplyMessage::NODES6));
std::vector<std::shared_ptr<DHTNode>> nodes;
if (nodesData) {
extractNodes(nodes, nodesData->uc(), nodesData->s().size());
}
return createFindNodeReplyMessage(remoteNode, std::move(nodes),
transactionID);
}
std::unique_ptr<DHTGetPeersMessage>
DHTMessageFactoryImpl::createGetPeersMessage(
const std::shared_ptr<DHTNode>& remoteNode, const unsigned char* infoHash,
const std::string& transactionID)
{
auto m = make_unique<DHTGetPeersMessage>(localNode_, remoteNode, infoHash,
transactionID);
m->setPeerAnnounceStorage(peerAnnounceStorage_);
m->setTokenTracker(tokenTracker_);
m->setBtRegistry(btRegistry_);
m->setFamily(family_);
setCommonProperty(m.get());
return m;
}
std::unique_ptr<DHTGetPeersReplyMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage(
const std::shared_ptr<DHTNode>& remoteNode, const Dict* dict,
const std::string& transactionID)
{
const Dict* rDict = getDictionary(dict, DHTResponseMessage::R);
const String* nodesData = downcast<String>(
rDict->get(family_ == AF_INET ? DHTGetPeersReplyMessage::NODES
: DHTGetPeersReplyMessage::NODES6));
std::vector<std::shared_ptr<DHTNode>> nodes;
if (nodesData) {
extractNodes(nodes, nodesData->uc(), nodesData->s().size());
}
const List* valuesList =
downcast<List>(rDict->get(DHTGetPeersReplyMessage::VALUES));
std::vector<std::shared_ptr<Peer>> peers;
size_t clen = bittorrent::getCompactLength(family_);
if (valuesList) {
for (auto& elem : *valuesList) {
const String* data = downcast<String>(elem);
if (data && data->s().size() == clen) {
auto addr = bittorrent::unpackcompact(data->uc(), family_);
if (addr.first.empty()) {
continue;
}
peers.push_back(std::make_shared<Peer>(addr.first, addr.second));
}
}
}
const String* token = getString(rDict, DHTGetPeersReplyMessage::TOKEN);
return createGetPeersReplyMessage(remoteNode, std::move(nodes),
std::move(peers), token->s(),
transactionID);
}
std::unique_ptr<DHTGetPeersReplyMessage>
DHTMessageFactoryImpl::createGetPeersReplyMessage(
const std::shared_ptr<DHTNode>& remoteNode,
std::vector<std::shared_ptr<DHTNode>> closestKNodes,
std::vector<std::shared_ptr<Peer>> values, const std::string& token,
const std::string& transactionID)
{
auto m = make_unique<DHTGetPeersReplyMessage>(family_, localNode_, remoteNode,
token, transactionID);
m->setClosestKNodes(std::move(closestKNodes));
m->setValues(std::move(values));
setCommonProperty(m.get());
return m;
}
std::unique_ptr<DHTAnnouncePeerMessage>
DHTMessageFactoryImpl::createAnnouncePeerMessage(
const std::shared_ptr<DHTNode>& remoteNode, const unsigned char* infoHash,
uint16_t tcpPort, const std::string& token,
const std::string& transactionID)
{
auto m = make_unique<DHTAnnouncePeerMessage>(localNode_, remoteNode, infoHash,
tcpPort, token, transactionID);
m->setPeerAnnounceStorage(peerAnnounceStorage_);
m->setTokenTracker(tokenTracker_);
setCommonProperty(m.get());
return m;
}
std::unique_ptr<DHTAnnouncePeerReplyMessage>
DHTMessageFactoryImpl::createAnnouncePeerReplyMessage(
const std::shared_ptr<DHTNode>& remoteNode,
const std::string& transactionID)
{
auto m = make_unique<DHTAnnouncePeerReplyMessage>(localNode_, remoteNode,
transactionID);
setCommonProperty(m.get());
return m;
}
std::unique_ptr<DHTUnknownMessage> DHTMessageFactoryImpl::createUnknownMessage(
const unsigned char* data, size_t length, const std::string& ipaddr,
uint16_t port)
{
return make_unique<DHTUnknownMessage>(localNode_, data, length, ipaddr, port);
}
void DHTMessageFactoryImpl::setRoutingTable(DHTRoutingTable* routingTable)
{
routingTable_ = routingTable;
}
void DHTMessageFactoryImpl::setConnection(DHTConnection* connection)
{
connection_ = connection;
}
void DHTMessageFactoryImpl::setMessageDispatcher(
DHTMessageDispatcher* dispatcher)
{
dispatcher_ = dispatcher;
}
void DHTMessageFactoryImpl::setPeerAnnounceStorage(
DHTPeerAnnounceStorage* storage)
{
peerAnnounceStorage_ = storage;
}
void DHTMessageFactoryImpl::setTokenTracker(DHTTokenTracker* tokenTracker)
{
tokenTracker_ = tokenTracker;
}
void DHTMessageFactoryImpl::setLocalNode(
const std::shared_ptr<DHTNode>& localNode)
{
localNode_ = localNode;
}
void DHTMessageFactoryImpl::setBtRegistry(BtRegistry* btRegistry)
{
btRegistry_ = btRegistry;
}
} // namespace aria2