diff --git a/src/DefaultBtInteractive.cc b/src/DefaultBtInteractive.cc index c74d8e7a0..4daae189d 100644 --- a/src/DefaultBtInteractive.cc +++ b/src/DefaultBtInteractive.cc @@ -189,6 +189,7 @@ void DefaultBtInteractive::doPostHandshakeProcessing() if (!metadataGetMode_) { addAllowedFastMessageToQueue(); } + peerStorage_->scheduleForcedChokeRound(); sendPendingMessage(); } diff --git a/src/DefaultPeerStorage.cc b/src/DefaultPeerStorage.cc index 60e13a725..6047462a6 100644 --- a/src/DefaultPeerStorage.cc +++ b/src/DefaultPeerStorage.cc @@ -62,7 +62,8 @@ DefaultPeerStorage::DefaultPeerStorage() : maxPeerListSize_(MAX_PEER_LIST_SIZE), seederStateChoke_(make_unique()), leecherStateChoke_(make_unique()), - lastTransferStatMapUpdated_(Timer::zero()) + lastTransferStatMapUpdated_(Timer::zero()), + forceChokeRound_(false) { } @@ -284,16 +285,21 @@ void DefaultPeerStorage::returnPeer(const std::shared_ptr& peer) bool DefaultPeerStorage::chokeRoundIntervalElapsed() { constexpr auto CHOKE_ROUND_INTERVAL = 10_s; + auto forceChokeRound = forceChokeRound_; + forceChokeRound_ = false; + if (pieceStorage_->downloadFinished()) { + auto interval = forceChokeRound ? 1_s : CHOKE_ROUND_INTERVAL; return seederStateChoke_->getLastRound().difference(global::wallclock()) >= - CHOKE_ROUND_INTERVAL; - } - else { - return leecherStateChoke_->getLastRound().difference(global::wallclock()) >= - CHOKE_ROUND_INTERVAL; + interval; } + + return leecherStateChoke_->getLastRound().difference(global::wallclock()) >= + CHOKE_ROUND_INTERVAL; } +void DefaultPeerStorage::scheduleForcedChokeRound() { forceChokeRound_ = true; } + void DefaultPeerStorage::executeChoke() { if (pieceStorage_->downloadFinished()) { diff --git a/src/DefaultPeerStorage.h b/src/DefaultPeerStorage.h index 3fd60a91a..f2714f827 100644 --- a/src/DefaultPeerStorage.h +++ b/src/DefaultPeerStorage.h @@ -74,6 +74,8 @@ private: std::map badPeers_; Timer lastBadPeerCleaned_; + bool forceChokeRound_; + bool isPeerAlreadyAdded(const std::shared_ptr& peer); void addUniqPeer(const std::shared_ptr& peer); @@ -113,6 +115,8 @@ public: virtual bool chokeRoundIntervalElapsed() CXX11_OVERRIDE; + virtual void scheduleForcedChokeRound() CXX11_OVERRIDE; + virtual void executeChoke() CXX11_OVERRIDE; void deleteUnusedPeer(size_t delSize); diff --git a/src/PeerInteractionCommand.cc b/src/PeerInteractionCommand.cc index 9d2898bc3..d2eec4242 100644 --- a/src/PeerInteractionCommand.cc +++ b/src/PeerInteractionCommand.cc @@ -347,29 +347,23 @@ bool PeerInteractionCommand::executeInternal() if (btInteractive_->countReceivedMessageInIteration() > 0) { updateKeepAlive(); } - if ((getPeer()->amInterested() && !getPeer()->peerChoking()) || - btInteractive_->countOutstandingRequest() || - (getPeer()->peerInterested() && !getPeer()->amChoking())) { - // Writable check to avoid slow seeding - if (btInteractive_->isSendingMessageInProgress()) { - setWriteCheckSocket(getSocket()); - } + // Writable check to avoid slow seeding + if (btInteractive_->isSendingMessageInProgress()) { + setWriteCheckSocket(getSocket()); + } - if (getDownloadEngine() - ->getRequestGroupMan() - ->doesOverallDownloadSpeedExceed() || - requestGroup_->doesDownloadSpeedExceed()) { - disableReadCheckSocket(); - setNoCheck(true); - } - else { - setReadCheckSocket(getSocket()); - } + if (getDownloadEngine() + ->getRequestGroupMan() + ->doesOverallDownloadSpeedExceed() || + requestGroup_->doesDownloadSpeedExceed()) { + disableReadCheckSocket(); + setNoCheck(true); } else { - disableReadCheckSocket(); + setReadCheckSocket(getSocket()); } + done = true; break; } diff --git a/src/PeerStorage.h b/src/PeerStorage.h index 08132b062..da546c06c 100644 --- a/src/PeerStorage.h +++ b/src/PeerStorage.h @@ -113,6 +113,11 @@ public: virtual bool chokeRoundIntervalElapsed() = 0; + /** + * Schedules choke round forcibly. + */ + virtual void scheduleForcedChokeRound() = 0; + virtual void executeChoke() = 0; }; diff --git a/test/MockPeerStorage.h b/test/MockPeerStorage.h index 05781823b..f20e896c5 100644 --- a/test/MockPeerStorage.h +++ b/test/MockPeerStorage.h @@ -84,6 +84,8 @@ public: virtual bool chokeRoundIntervalElapsed() CXX11_OVERRIDE { return false; } + virtual void scheduleForcedChokeRound() CXX11_OVERRIDE {} + virtual void executeChoke() CXX11_OVERRIDE { ++numChokeExecuted_; } int getNumChokeExecuted() const { return numChokeExecuted_; }