diff --git a/configure.ac b/configure.ac index 2f8279307..8cf19c1e4 100644 --- a/configure.ac +++ b/configure.ac @@ -40,6 +40,7 @@ AC_DEFINE_UNQUOTED([TARGET], ["$target"], [Define target-type]) # Checks for arguments. ARIA2_ARG_WITHOUT([libuv]) ARIA2_ARG_WITHOUT([appletls]) +ARIA2_ARG_WITHOUT([wintls]) ARIA2_ARG_WITHOUT([gnutls]) ARIA2_ARG_WITHOUT([libnettle]) ARIA2_ARG_WITHOUT([libgmp]) @@ -286,8 +287,30 @@ case "$host" in *darwin*) have_osx="yes" ;; + *mingw*) + AC_CHECK_HEADERS([windows.h \ + winsock2.h \ + ws2tcpip.h \ + mmsystem.h \ + io.h \ + iphlpapi.h\ + winioctl.h \ + share.h], [], [], + [[ +#ifdef HAVE_WS2TCPIP_H +# include +#endif +#ifdef HAVE_WINSOCK2_H +# include +#endif +#ifdef HAVE_WINDOWS_H +# include +#endif + ]]) + ;; esac + if test "x$with_appletls" = "xyes"; then AC_MSG_CHECKING([whether to enable Mac OS X native SSL/TLS]) if test "x$have_osx" = "xyes"; then @@ -303,6 +326,23 @@ if test "x$with_appletls" = "xyes"; then fi fi +if test "x$with_wintls" = "xyes"; then + AC_SEARCH_LIBS([CryptAcquireContextW], [advapi32], [ + AC_CHECK_HEADER([wincrypt.h], [have_wincrypt=yes], [have_wincrypt=no], + [[ +#ifdef HAVE_WINDOWS_H +# include +#endif + ]]) + break; + ], [have_wincrypt=no]) + if test "x$have_wincrypt" != "xyes"; then + if test "x$with_wintls_requested" = "xyes"; then + ARIA2_DEP_NOT_MET([wintls]) + fi + fi +fi + if test "x$with_gnutls" = "xyes" && test "x$have_appletls" != "xyes"; then # gnutls >= 2.8 doesn't have libgnutls-config anymore. We require # 2.2.0 because we use gnutls_priority_set_direct() @@ -398,17 +438,22 @@ if test "x$have_osx" == "xyes"; then use_md="apple" AC_DEFINE([USE_APPLE_MD], [1], [What message digest implementation to use]) else - if test "x$have_libnettle" = "xyes"; then - AC_DEFINE([USE_LIBNETTLE_MD], [1], [What message digest implementation to use]) - use_md="libnettle" + if test "x$have_wincrypt" == "xyes"; then + use_md="windows" + AC_DEFINE([USE_WINDOWS_MD], [1], [What message digest implementation to use]) else - if test "x$have_libgcrypt" = "xyes"; then - AC_DEFINE([USE_LIBGCRYPT_MD], [1], [What message digest implementation to use]) - use_md="libgcrypt" + if test "x$have_libnettle" = "xyes"; then + AC_DEFINE([USE_LIBNETTLE_MD], [1], [What message digest implementation to use]) + use_md="libnettle" else - if test "x$have_openssl" = "xyes"; then - AC_DEFINE([USE_OPENSSL_MD], [1], [What message digest implementation to use]) - use_md="openssl" + if test "x$have_libgcrypt" = "xyes"; then + AC_DEFINE([USE_LIBGCRYPT_MD], [1], [What message digest implementation to use]) + use_md="libgcrypt" + else + if test "x$have_openssl" = "xyes"; then + AC_DEFINE([USE_OPENSSL_MD], [1], [What message digest implementation to use]) + use_md="openssl" + fi fi fi fi @@ -427,6 +472,7 @@ fi AM_CONDITIONAL([HAVE_OSX], [ test "x$have_osx" = "xyes" ]) AM_CONDITIONAL([HAVE_APPLETLS], [ test "x$have_appletls" = "xyes" ]) AM_CONDITIONAL([USE_APPLE_MD], [ test "x$use_md" = "xapple" ]) +AM_CONDITIONAL([USE_WINDOWS_MD], [ test "x$use_md" = "xwindows" ]) AM_CONDITIONAL([HAVE_LIBGNUTLS], [ test "x$have_libgnutls" = "xyes" ]) AM_CONDITIONAL([HAVE_LIBNETTLE], [ test "x$have_libnettle" = "xyes" ]) AM_CONDITIONAL([USE_LIBNETTLE_MD], [ test "x$use_md" = "xlibnettle"]) @@ -519,30 +565,6 @@ esac AC_FUNC_ALLOCA AC_HEADER_STDC -case "$host" in - *mingw*) - AC_CHECK_HEADERS([windows.h \ - winsock2.h \ - ws2tcpip.h \ - mmsystem.h \ - io.h \ - iphlpapi.h\ - winioctl.h \ - share.h], [], [], - [[ -#ifdef HAVE_WS2TCPIP_H -# include -#endif -#ifdef HAVE_WINSOCK2_H -# include -#endif -#ifdef HAVE_WINDOWS_H -# include -#endif - ]]) - ;; -esac - AC_CHECK_HEADERS([argz.h \ arpa/inet.h \ fcntl.h \ diff --git a/src/Makefile.am b/src/Makefile.am index 01532752f..c36521811 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -329,6 +329,10 @@ SRCS += AppleTLSContext.cc AppleTLSContext.h \ AppleTLSSession.cc AppleTLSSession.h endif # HAVE_APPLETLS +if USE_WINDOWS_MD +SRCS += WinMessageDigestImpl.cc WinMessageDigestImpl.h +endif # USE_WINDOWS_MD + if HAVE_LIBGNUTLS SRCS += LibgnutlsTLSContext.cc LibgnutlsTLSContext.h \ LibgnutlsTLSSession.cc LibgnutlsTLSSession.h diff --git a/src/MessageDigestImpl.h b/src/MessageDigestImpl.h index e7de69bbc..aee3057f3 100644 --- a/src/MessageDigestImpl.h +++ b/src/MessageDigestImpl.h @@ -38,6 +38,8 @@ #ifdef USE_APPLE_MD # include "AppleMessageDigestImpl.h" +#elif defined(USE_WINDOWS_MD) +# include "WinMessageDigestImpl.h" #elif defined(USE_LIBNETTLE_MD) # include "LibnettleMessageDigestImpl.h" #elif defined(USE_LIBGCRYPT_MD) diff --git a/src/WinMessageDigestImpl.cc b/src/WinMessageDigestImpl.cc new file mode 100644 index 000000000..bbc6d8973 --- /dev/null +++ b/src/WinMessageDigestImpl.cc @@ -0,0 +1,177 @@ +/* */ + +#include "WinMessageDigestImpl.h" + +#include + +#include "array_fun.h" +#include "a2functional.h" +#include "HashFuncEntry.h" +#include "DlAbortEx.h" + +namespace { +using namespace aria2; + +class Context { +private: + HCRYPTPROV provider_; +public: + Context() { + if (!::CryptAcquireContext(&provider_, nullptr, nullptr, PROV_RSA_FULL, + CRYPT_VERIFYCONTEXT)) { + throw DL_ABORT_EX("Failed to get cryptographic provider"); + } + } + ~Context() { + ::CryptReleaseContext(provider_, 0); + } + + HCRYPTPROV get() { + return provider_; + } +}; + +// XXX static OK? +static Context context_; + +} // namespace + +namespace aria2 { + +template +class MessageDigestBase : public MessageDigestImpl { +private: + HCRYPTHASH hash_; + DWORD len_; + + void destroy() { + if (hash_) { + ::CryptDestroyHash(hash_); + hash_ = 0; + } + } + +public: + MessageDigestBase() : hash_(0), len_(0) { reset(); } + virtual ~MessageDigestBase() { destroy(); } + + virtual size_t getDigestLength() const CXX11_OVERRIDE { + return len_; + } + virtual void reset() CXX11_OVERRIDE { + destroy(); + if (!::CryptCreateHash(context_.get(), id, 0, 0, &hash_)) { + throw DL_ABORT_EX("Failed to create hash"); + } + + DWORD len = sizeof(len_); + if (!::CryptGetHashParam(hash_, HP_HASHSIZE, reinterpret_cast(&len_), + &len, 0)) { + throw DL_ABORT_EX("Failed to create hash"); + } + } + virtual void update(const void* data, size_t length) CXX11_OVERRIDE { + auto bytes = reinterpret_cast(data); + while (length) { + DWORD l = std::min(length, (size_t)std::numeric_limits::max()); + if (!::CryptHashData(hash_, bytes, l, 0)) { + throw DL_ABORT_EX("Failed to update hash"); + } + length -= l; + bytes += l; + } + } + virtual void digest(unsigned char* md) CXX11_OVERRIDE { + DWORD len = len_; + if (!::CryptGetHashParam(hash_, HP_HASHVAL, md, &len, 0)) { + throw DL_ABORT_EX("Failed to create hash digest"); + } + } +}; + +typedef MessageDigestBase MessageDigestMD5; +typedef MessageDigestBase MessageDigestSHA1; +typedef MessageDigestBase MessageDigestSHA256; +typedef MessageDigestBase MessageDigestSHA384; +typedef MessageDigestBase MessageDigestSHA512; + +std::unique_ptr MessageDigestImpl::sha1() +{ + return std::unique_ptr(new MessageDigestSHA1()); +} + +std::unique_ptr MessageDigestImpl::create( + const std::string& hashType) +{ + if (hashType == "sha-1") { + return make_unique(); + } + if (hashType == "sha-256") { + return make_unique(); + } + if (hashType == "sha-384") { + return make_unique(); + } + if (hashType == "sha-512") { + return make_unique(); + } + if (hashType == "md5") { + return make_unique(); + } + return nullptr; +} + +bool MessageDigestImpl::supports(const std::string& hashType) +{ + try { + return !!create(hashType); + } + catch (RecoverableException& ex) { + // no op + } + return false; +} + +size_t MessageDigestImpl::getDigestLength(const std::string& hashType) +{ + std::unique_ptr impl = create(hashType); + if (!impl) { + return 0; + } + return impl->getDigestLength(); +} + +} // namespace aria2 diff --git a/src/WinMessageDigestImpl.h b/src/WinMessageDigestImpl.h new file mode 100644 index 000000000..758937a78 --- /dev/null +++ b/src/WinMessageDigestImpl.h @@ -0,0 +1,70 @@ +/* */ +#ifndef D_WIN_MESSAGE_DIGEST_IMPL_H +#define D_WIN_MESSAGE_DIGEST_IMPL_H + +#include "common.h" + +#include +#include + +namespace aria2 { + +class MessageDigestImpl { +public: + virtual ~MessageDigestImpl() {} + static std::unique_ptr sha1(); + static std::unique_ptr create(const std::string& hashType); + + static bool supports(const std::string& hashType); + static size_t getDigestLength(const std::string& hashType); + +public: + virtual size_t getDigestLength() const = 0; + virtual void reset() = 0; + virtual void update(const void* data, size_t length) = 0; + virtual void digest(unsigned char* md) = 0; + +protected: + MessageDigestImpl() {} + +private: + MessageDigestImpl(const MessageDigestImpl&); + MessageDigestImpl& operator=(const MessageDigestImpl&); +}; + +} // namespace aria2 + +#endif // D_WIN_MESSAGE_DIGEST_IMPL_H