00001
#include "factory.h"
00002
#include "integer.h"
00003
#include "filters.h"
00004
#include "hex.h"
00005
#include "randpool.h"
00006
#include "files.h"
00007
#include "trunhash.h"
00008
#include <iostream>
00009
#include <memory>
00010
00011 USING_NAMESPACE(CryptoPP)
00012 USING_NAMESPACE(std)
00013
00014
RandomPool & GlobalRNG();
00015
void RegisterFactories();
00016
00017 typedef std::map<std::string, std::string> TestData;
00018
00019 class TestFailure : public
Exception
00020 {
00021
public:
00022 TestFailure() : Exception(OTHER_ERROR, "Validation test failed") {}
00023 };
00024
00025
// Test record currently being processed. TestDataFile() points this at its
// working record so SignalTestFailure()/SignalTestError() can dump the
// offending fields before throwing.
static const TestData* s_currentTestData = 0;
00026
00027
void OutputTestData(
const TestData &v)
00028 {
00029
for (TestData::const_iterator i = v.begin(); i != v.end(); ++i)
00030 {
00031 cerr << i->first <<
": " << i->second << endl;
00032 }
00033 }
00034
00035
void SignalTestFailure()
00036 {
00037 OutputTestData(*s_currentTestData);
00038
throw TestFailure();
00039 }
00040
00041
void SignalTestError()
00042 {
00043 OutputTestData(*s_currentTestData);
00044
throw Exception(Exception::OTHER_ERROR,
"Unexpected error during validation test");
00045 }
00046
00047
class TestDataNameValuePairs :
public NameValuePairs
00048 {
00049
public:
00050 TestDataNameValuePairs(
const TestData &data) : m_data(data) {}
00051
00052
virtual bool GetVoidValue(
const char *name,
const std::type_info &valueType,
void *pValue)
const
00053
{
00054 TestData::const_iterator i = m_data.find(name);
00055
if (i == m_data.end())
00056
return false;
00057
00058
const std::string &value = i->second;
00059
00060
if (valueType ==
typeid(
int))
00061 *reinterpret_cast<int *>(pValue) = atoi(value.c_str());
00062
else if (valueType ==
typeid(
Integer))
00063 *reinterpret_cast<Integer *>(pValue) =
Integer((std::string(value) +
"h").c_str());
00064
else
00065
throw ValueTypeMismatch(name,
typeid(std::string), valueType);
00066
00067
return true;
00068 }
00069
00070
private:
00071
const TestData &m_data;
00072 };
00073
00074
const std::string & GetRequiredDatum(
const TestData &data,
const char *name)
00075 {
00076 TestData::const_iterator i = data.find(name);
00077
if (i == data.end())
00078 SignalTestError();
00079
return i->second;
00080 }
00081
00082
void PutDecodedDatumInto(
const TestData &data,
const char *name,
BufferedTransformation &target)
00083 {
00084 std::string s1 = GetRequiredDatum(data, name), s2;
00085
00086
int repeat = 1;
00087
if (s1[0] ==
'r')
00088 {
00089 repeat = atoi(s1.c_str()+1);
00090 s1 = s1.substr(s1.find(
' ')+1);
00091 }
00092
00093
if (s1[0] ==
'\"')
00094 s2 = s1.substr(1, s1.find(
'\"', 1)-1);
00095
else if (s1.substr(0, 2) ==
"0x")
00096
StringSource(s1.substr(2),
true,
new HexDecoder(
new StringSink(s2)));
00097
else
00098
StringSource(s1,
true,
new HexDecoder(
new StringSink(s2)));
00099
00100
while (repeat--)
00101 target.
Put((
const byte *)s2.data(), s2.size());
00102 }
00103
00104 std::string GetDecodedDatum(
const TestData &data,
const char *name)
00105 {
00106 std::string s;
00107 PutDecodedDatumInto(data, name,
StringSink(s).Ref());
00108
return s;
00109 }
00110
00111
// Validates both halves of a key pair with Validate(GlobalRNG(), 3) and
// signals a test failure if either is rejected. The public key is checked
// first, and the private key is only checked if the public key passes.
// NOTE(review): despite the name, no check that pub and priv belong together
// is performed here.
void TestKeyPairValidAndConsistent(CryptoMaterial &pub, const CryptoMaterial &priv)
{
	const bool bothValid = pub.Validate(GlobalRNG(), 3) && priv.Validate(GlobalRNG(), 3);
	if (!bothValid)
		SignalTestFailure();
}
00126
00127
// Runs one signature-scheme test record. The "Test" field selects the check:
//   Verify / NotVerify         - verify Signature over Message (public key only)
//   PublicKeyValid             - validate the public key
//   KeyPairValidAndConsistent  - validate both keys
//   Sign                       - prints a hex signature to stdout, then fails
//   anything else              - SignalTestError()
// "KeyFormat" selects how keys are loaded: DER blobs (PublicKey/PrivateKey)
// or individual Component fields. The private key is loaded only for tests
// that need it.
void TestSignatureScheme(TestData &v)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	std::auto_ptr<PK_Signer> signer(ObjectFactoryRegistry<PK_Signer>::Registry().CreateObject(name.c_str()));
	std::auto_ptr<PK_Verifier> verifier(ObjectFactoryRegistry<PK_Verifier>::Registry().CreateObject(name.c_str()));

	TestDataNameValuePairs pairs(v);
	std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
	const bool derKeys = (keyFormat == "DER");
	const bool componentKeys = (keyFormat == "Component");

	if (derKeys)
		verifier->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
	else if (componentKeys)
		verifier->AccessMaterial().AssignFrom(pairs);

	// --- Tests that need only the public key ---
	if (test == "Verify" || test == "NotVerify")
	{
		VerifierFilter verifierFilter(*verifier, NULL, VerifierFilter::SIGNATURE_AT_BEGIN);
		PutDecodedDatumInto(v, "Signature", verifierFilter);
		PutDecodedDatumInto(v, "Message", verifierFilter);
		verifierFilter.MessageEnd();

		const bool expectValid = (test != "NotVerify");
		if (verifierFilter.GetLastResult() != expectValid)
			SignalTestFailure();
		return;
	}

	if (test == "PublicKeyValid")
	{
		if (!verifier->GetMaterial().Validate(GlobalRNG(), 3))
			SignalTestFailure();
		return;
	}

	// --- Remaining tests also need the private key ---
	if (derKeys)
		signer->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
	else if (componentKeys)
		signer->AccessMaterial().AssignFrom(pairs);

	if (test == "KeyPairValidAndConsistent")
	{
		TestKeyPairValidAndConsistent(verifier->AccessMaterial(), signer->GetMaterial());
	}
	else if (test == "Sign")
	{
		SignerFilter f(GlobalRNG(), *signer, new HexEncoder(new FileSink(cout)));
		StringSource ss(GetDecodedDatum(v, "Message"), true, new Redirector(f));
		SignalTestFailure();
	}
	else
	{
		// DeterministicSign, RandomSign, GenerateKey and unknown test names
		// are all unsupported.
		SignalTestError();
		assert(false);
	}
}
00199
00200
// Runs one asymmetric-cipher test record. The "Test" field selects the check:
//   DecryptMatch               - decrypt Ciphertext and compare with Plaintext
//   KeyPairValidAndConsistent  - validate both keys
//   anything else              - SignalTestError()
// "KeyFormat" selects DER blobs (PrivateKey/PublicKey) or Component fields.
void TestEncryptionScheme(TestData &v)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	std::auto_ptr<PK_Encryptor> encryptor(ObjectFactoryRegistry<PK_Encryptor>::Registry().CreateObject(name.c_str()));
	std::auto_ptr<PK_Decryptor> decryptor(ObjectFactoryRegistry<PK_Decryptor>::Registry().CreateObject(name.c_str()));

	std::string keyFormat = GetRequiredDatum(v, "KeyFormat");
	if (keyFormat == "DER")
	{
		decryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PrivateKey")).Ref());
		encryptor->AccessMaterial().Load(StringStore(GetDecodedDatum(v, "PublicKey")).Ref());
	}
	else if (keyFormat == "Component")
	{
		TestDataNameValuePairs pairs(v);
		decryptor->AccessMaterial().AssignFrom(pairs);
		encryptor->AccessMaterial().AssignFrom(pairs);
	}

	if (test == "KeyPairValidAndConsistent")
	{
		TestKeyPairValidAndConsistent(encryptor->AccessMaterial(), decryptor->GetMaterial());
		return;
	}

	if (test != "DecryptMatch")
	{
		SignalTestError();
		assert(false);
	}

	std::string decrypted, expected = GetDecodedDatum(v, "Plaintext");
	StringSource ss(GetDecodedDatum(v, "Ciphertext"), true, new PK_DecryptorFilter(GlobalRNG(), *decryptor, new StringSink(decrypted)));
	if (decrypted != expected)
		SignalTestFailure();
}
00239
00240
// Runs one message-digest (testDigest == true) or MAC (testDigest == false)
// test record. For a MAC, the "Key" field keys the algorithm first.
// Supported "Test" values:
//   Verify           - Digest must match the hash/MAC of Message
//   VerifyTruncated  - same, compared on the first TruncatedSize bytes only
//   NotVerify        - Digest must NOT match
// Anything else is reported with SignalTestError().
void TestDigestOrMAC(TestData &v, bool testDigest)
{
	std::string name = GetRequiredDatum(v, "Name");
	std::string test = GetRequiredDatum(v, "Test");

	member_ptr<MessageAuthenticationCode> mac;
	member_ptr<HashTransformation> hash;
	HashTransformation *pHash = NULL;

	if (!testDigest)
	{
		mac.reset(ObjectFactoryRegistry<MessageAuthenticationCode>::Registry().CreateObject(name.c_str()));
		pHash = mac.get();
		std::string key = GetDecodedDatum(v, "Key");
		mac->SetKey((const byte *)key.c_str(), key.size());
	}
	else
	{
		hash.reset(ObjectFactoryRegistry<HashTransformation>::Registry().CreateObject(name.c_str()));
		pHash = hash.get();
	}

	const bool truncated = (test == "VerifyTruncated");
	const bool expectMismatch = (test == "NotVerify");
	if (!(test == "Verify" || truncated || expectMismatch))
	{
		SignalTestError();
		assert(false);
	}

	int digestSize = pHash->DigestSize();
	if (truncated)
		digestSize = atoi(GetRequiredDatum(v, "TruncatedSize").c_str());

	TruncatedHashModule thash(*pHash, digestSize);
	HashVerificationFilter verifierFilter(thash, NULL, HashVerificationFilter::HASH_AT_BEGIN);
	PutDecodedDatumInto(v, "Digest", verifierFilter);
	PutDecodedDatumInto(v, "Message", verifierFilter);
	verifierFilter.MessageEnd();

	if (verifierFilter.GetLastResult() == expectMismatch)
		SignalTestFailure();
}
00281
00282
bool GetField(std::istream &is, std::string &name, std::string &value)
00283 {
00284 name.resize(0);
00285 is >> name;
00286
if (name.empty())
00287
return false;
00288
00289
if (name[name.size()-1] !=
':')
00290 SignalTestError();
00291 name.erase(name.size()-1);
00292
00293
while (is.peek() ==
' ')
00294 is.ignore(1);
00295
00296
00297
char buffer[128];
00298 value.resize(0);
00299
bool continueLine;
00300
00301
do
00302 {
00303
do
00304 {
00305 is.get(buffer,
sizeof(buffer));
00306 value += buffer;
00307 }
00308
while (buffer[0] != 0);
00309 is.clear();
00310 is.ignore();
00311
00312
if (value[value.size()-1] ==
'\\')
00313 {
00314 value.resize(value.size()-1);
00315 continueLine =
true;
00316 }
00317
else
00318 continueLine =
false;
00319
00320 std::string::size_type i = value.find(
'#');
00321
if (i != std::string::npos)
00322 value.erase(i);
00323 }
00324
while (continueLine);
00325
00326
return true;
00327 }
00328
00329
void OutputPair(
const NameValuePairs &v,
const char *name)
00330 {
00331
Integer x;
00332
bool b = v.
GetValue(name, x);
00333 assert(b);
00334 cout << name <<
": \\\n ";
00335 x.Encode(
HexEncoder(
new FileSink(cout),
false, 64,
"\\\n ").Ref(), x.MinEncodedSize());
00336 cout << endl;
00337 }
00338
00339
void OutputNameValuePairs(
const NameValuePairs &v)
00340 {
00341 std::string names = v.
GetValueNames();
00342 string::size_type i = 0;
00343
while (i < names.size())
00344 {
00345 string::size_type j = names.find_first_of (
';', i);
00346
00347
if (j == string::npos)
00348
return;
00349
else
00350 {
00351 std::string name = names.substr(i, j-i);
00352
if (name.find(
':') == string::npos)
00353 OutputPair(v, name.c_str());
00354 }
00355
00356 i = j + 1;
00357 }
00358 }
00359
00360
void TestDataFile(
const std::string &filename,
unsigned int &totalTests,
unsigned int &failedTests)
00361 {
00362 std::ifstream file(filename.c_str());
00363 TestData v;
00364 s_currentTestData = &v;
00365 std::string name, value, lastAlgName;
00366
00367
while (file)
00368 {
00369
while (file.peek() ==
'#')
00370 file.ignore(INT_MAX,
'\n');
00371
00372
if (file.peek() ==
'\n')
00373 v.clear();
00374
00375
if (!GetField(file, name, value))
00376
break;
00377 v[name] = value;
00378
00379
if (name ==
"Test")
00380 {
00381
bool failed =
true;
00382 std::string algType = GetRequiredDatum(v,
"AlgorithmType");
00383
00384
if (lastAlgName != GetRequiredDatum(v,
"Name"))
00385 {
00386 lastAlgName = GetRequiredDatum(v,
"Name");
00387 cout <<
"\nTesting " << algType.c_str() <<
" algorithm " << lastAlgName.c_str() <<
".\n";
00388 }
00389
00390
try
00391 {
00392
if (algType ==
"Signature")
00393 TestSignatureScheme(v);
00394
else if (algType ==
"AsymmetricCipher")
00395 TestEncryptionScheme(v);
00396
else if (algType ==
"MessageDigest")
00397 TestDigestOrMAC(v,
true);
00398
else if (algType ==
"MAC")
00399 TestDigestOrMAC(v,
false);
00400
else if (algType ==
"FileList")
00401 TestDataFile(GetRequiredDatum(v,
"Test"), totalTests, failedTests);
00402
else
00403 SignalTestError();
00404 failed =
false;
00405 }
00406
catch (TestFailure &)
00407 {
00408 cout <<
"\nTest failed.\n";
00409 }
00410
catch (CryptoPP::Exception &e)
00411 {
00412 cout <<
"\nCryptoPP::Exception caught: " << e.what() << endl;
00413 }
00414
catch (std::exception &e)
00415 {
00416 cout <<
"\nstd::exception caught: " << e.what() << endl;
00417 }
00418
00419
if (failed)
00420 {
00421 cout <<
"Skipping to next test.\n";
00422 failedTests++;
00423 }
00424
else
00425 cout <<
"." << flush;
00426
00427 totalTests++;
00428 }
00429 }
00430 }
00431
00432
bool RunTestDataFile(
const char *filename)
00433 {
00434 RegisterFactories();
00435
unsigned int totalTests = 0, failedTests = 0;
00436 TestDataFile(filename, totalTests, failedTests);
00437 cout <<
"\nTests complete. Total tests = " << totalTests <<
". Failed tests = " << failedTests <<
".\n";
00438
if (failedTests != 0)
00439 cout <<
"SOME TESTS FAILED!\n";
00440
return failedTests == 0;
00441 }