diff --git a/.travis.yml b/.travis.yml index dfe577ce..d570e593 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,8 +1,8 @@ language: python -cache: pip python: - 2.7 install: + - pip install -U pip wheel - pip install -r requirements.txt before_script: - openssl version -a @@ -16,3 +16,6 @@ before_install: after_success: - codecov - coveralls --rcfile=src/Test/coverage.ini +cache: + directories: + - $HOME/.cache/pip \ No newline at end of file diff --git a/plugins/AnnounceZero/AnnounceZeroPlugin.py b/plugins/AnnounceZero/AnnounceZeroPlugin.py new file mode 100644 index 00000000..de80aae5 --- /dev/null +++ b/plugins/AnnounceZero/AnnounceZeroPlugin.py @@ -0,0 +1,118 @@ +import hashlib +import time + +from Plugin import PluginManager +from Peer import Peer +from util import helper +from Crypt import CryptRsa + +allow_reload = False # No source reload supported in this plugin +time_full_announced = {} # Tracker address: Last announced all site to tracker +connection_pool = {} # Tracker address: Peer object + + +# Process result got back from tracker +def processPeerRes(site, peers): + added = 0 + # Ip4 + found_ip4 = 0 + for packed_address in peers["ip4"]: + found_ip4 += 1 + peer_ip, peer_port = helper.unpackAddress(packed_address) + if site.addPeer(peer_ip, peer_port): + added += 1 + # Onion + found_onion = 0 + for packed_address in peers["onion"]: + found_onion += 1 + peer_onion, peer_port = helper.unpackOnionAddress(packed_address) + if site.addPeer(peer_onion, peer_port): + added += 1 + + if added: + site.worker_manager.onPeers() + site.updateWebsocket(peers_added=added) + site.log.debug("Found %s ip4, %s onion peers, new: %s" % (found_ip4, found_onion, added)) + + +@PluginManager.registerTo("Site") +class SitePlugin(object): + def announceTracker(self, tracker_protocol, tracker_address, fileserver_port=0, add_types=[], my_peer_id="", mode="start"): + if tracker_protocol != "zero": + return super(SitePlugin, self).announceTracker( + tracker_protocol, tracker_address, fileserver_port, add_types, my_peer_id, mode + ) + + s = time.time() + + need_types = ["ip4"] + if self.connection_server and self.connection_server.tor_manager.enabled: + need_types.append("onion") + + if mode == "start" or mode == "more": # Single: Announce only this site + sites = [self] + full_announce = False + else: # Multi: Announce all currently serving site + full_announce = True + if time.time() - time_full_announced.get(tracker_address, 0) < 60 * 5: # No reannounce all sites within 5 minute + return True + time_full_announced[tracker_address] = time.time() + from Site import SiteManager + sites = [site for site in SiteManager.site_manager.sites.values() if site.settings["serving"]] + + # Create request + request = { + "hashes": [], "onions": [], "port": fileserver_port, "need_types": need_types, "need_num": 20, "add": add_types + } + for site in sites: + if "onion" in add_types: + onion = self.connection_server.tor_manager.getOnion(site.address) + request["onions"].append(onion) + request["hashes"].append(hashlib.sha256(site.address).digest()) + + # Tracker can remove sites that we don't announce + if full_announce: + request["delete"] = True + + # Sent request to tracker + tracker = connection_pool.get(tracker_address) # Re-use tracker connection if possible + if not tracker: + tracker_ip, tracker_port = tracker_address.split(":") + tracker = Peer(tracker_ip, tracker_port, connection_server=self.connection_server) + connection_pool[tracker_address] = tracker + res = tracker.request("announce", request) + + if not res or "peers" not in res: + self.log.debug("Announce to %s failed: %s" % (tracker_address, res)) + if full_announce: + time_full_announced[tracker_address] = 0 + return False + + # Add peers from response to site + site_index = 0 + for site_res in res["peers"]: + site = sites[site_index] + processPeerRes(site, site_res) + site_index += 1 + + # Check if we need to sign prove the onion addresses + if "onion_sign_this" in res: + self.log.debug("Signing %s for %s to add %s onions" % (res["onion_sign_this"], tracker_address, len(sites))) + request["onion_signs"] = {} + request["onion_sign_this"] = res["onion_sign_this"] + request["need_num"] = 0 + for site in sites: + onion = self.connection_server.tor_manager.getOnion(site.address) + sign = CryptRsa.sign(res["onion_sign_this"], self.connection_server.tor_manager.getPrivatekey(onion)) + request["onion_signs"][self.connection_server.tor_manager.getPublickey(onion)] = sign + res = tracker.request("announce", request) + if not res or "onion_sign_this" in res: + self.log.debug("Announce onion address to %s failed: %s" % (tracker_address, res)) + if full_announce: + time_full_announced[tracker_address] = 0 + return False + + if full_announce: + tracker.remove() # Close connection, we don't need it in next 5 minute + + return time.time() - s diff --git a/plugins/AnnounceZero/__init__.py b/plugins/AnnounceZero/__init__.py new file mode 100644 index 00000000..4b9cbe10 --- /dev/null +++ b/plugins/AnnounceZero/__init__.py @@ -0,0 +1 @@ +import AnnounceZeroPlugin \ No newline at end of file diff --git a/plugins/Sidebar/SidebarPlugin.py b/plugins/Sidebar/SidebarPlugin.py index ab258822..21269935 100644 --- a/plugins/Sidebar/SidebarPlugin.py +++ b/plugins/Sidebar/SidebarPlugin.py @@ -60,6 +60,7 @@ class UiWebsocketPlugin(object): def sidebarRenderPeerStats(self, body, site): connected = len([peer for peer in site.peers.values() if peer.connection and peer.connection.connected]) connectable = len([peer_id for peer_id in site.peers.keys() if not peer_id.endswith(":0")]) + onion = len([peer_id for peer_id in site.peers.keys() if ".onion" in peer_id]) peers_total = len(site.peers) if peers_total: percent_connected = float(connected) / peers_total @@ -77,6 +78,7 @@ class UiWebsocketPlugin(object): @@ -201,7 +203,6 @@ class UiWebsocketPlugin(object): """.format(**locals())) - def sidebarRenderOptionalFileStats(self, body, site): size_total = 0.0 size_downloaded = 0.0 @@ -213,7 +214,6 @@ class UiWebsocketPlugin(object): if site.content_manager.hashfield.hasHash(file_details["sha512"]): size_downloaded += file_details["size"] - if not size_total: return False @@ -365,30 +365,43 @@ class UiWebsocketPlugin(object): import urllib import gzip import shutil + from util import helper self.log.info("Downloading GeoLite2 City database...") self.cmd("notification", ["geolite-info", "Downloading GeoLite2 City database (one time only, ~15MB)...", 0]) - try: - # Download - file = urllib.urlopen("http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz") - data = StringIO.StringIO() - while True: - buff = file.read(1024 * 16) - if not buff: - break - data.write(buff) - self.log.info("GeoLite2 City database downloaded (%s bytes), unpacking..." % data.tell()) - data.seek(0) + db_urls = [ + "http://geolite.maxmind.com/download/geoip/database/GeoLite2-City.mmdb.gz", + "https://raw.githubusercontent.com/texnikru/GeoLite2-Database/master/GeoLite2-City.mmdb.gz" + ] + for db_url in db_urls: + try: + # Download + response = helper.httpRequest(db_url) - # Unpack - with gzip.GzipFile(fileobj=data) as gzip_file: - shutil.copyfileobj(gzip_file, open(db_path, "wb")) + data = StringIO.StringIO() + while True: + buff = response.read(1024 * 512) + if not buff: + break + data.write(buff) + self.log.info("GeoLite2 City database downloaded (%s bytes), unpacking..." % data.tell()) + data.seek(0) - self.cmd("notification", ["geolite-done", "GeoLite2 City database downloaded!", 5000]) - time.sleep(2) # Wait for notify animation - except Exception, err: - self.cmd("notification", ["geolite-error", "GeoLite2 City database download error: %s!" % err, 0]) - raise err + # Unpack + with gzip.GzipFile(fileobj=data) as gzip_file: + shutil.copyfileobj(gzip_file, open(db_path, "wb")) + + self.cmd("notification", ["geolite-done", "GeoLite2 City database downloaded!", 5000]) + time.sleep(2) # Wait for notify animation + return True + except Exception, err: + self.log.error("Error downloading %s: %s" % (db_url, err)) + pass + self.cmd("notification", [ + "geolite-error", + "GeoLite2 City database download error: %s!
Please download and unpack to data dir:
%s" % (err, db_urls[0]), + 0 + ]) def actionSidebarGetPeers(self, to): permissions = self.getPermissions(to) @@ -397,8 +410,9 @@ class UiWebsocketPlugin(object): try: import maxminddb db_path = config.data_dir + '/GeoLite2-City.mmdb' - if not os.path.isfile(db_path): - self.downloadGeoLiteDb(db_path) + if not os.path.isfile(db_path) or os.path.getsize(db_path) == 0: + if not self.downloadGeoLiteDb(db_path): + return False geodb = maxminddb.open_database(db_path) peers = self.site.peers.values() @@ -426,7 +440,10 @@ class UiWebsocketPlugin(object): if peer.ip in loc_cache: loc = loc_cache[peer.ip] else: - loc = geodb.get(peer.ip) + try: + loc = geodb.get(peer.ip) + except: + loc = None loc_cache[peer.ip] = loc if not loc or "location" not in loc: continue @@ -458,7 +475,6 @@ class UiWebsocketPlugin(object): return self.response(to, "You don't have permission to run this command") self.site.settings["own"] = bool(owned) - def actionSiteSetAutodownloadoptional(self, to, owned): permissions = self.getPermissions(to) if "ADMIN" not in permissions: diff --git a/plugins/Stats/StatsPlugin.py b/plugins/Stats/StatsPlugin.py index cef76c70..38ea924f 100644 --- a/plugins/Stats/StatsPlugin.py +++ b/plugins/Stats/StatsPlugin.py @@ -53,6 +53,7 @@ class UiRequestPlugin(object): """ @@ -113,15 +114,20 @@ class UiRequestPlugin(object): ]) yield "" + # Tor hidden services + yield "

Tor hidden services (status: %s):
" % main.file_server.tor_manager.status + for site_address, onion in main.file_server.tor_manager.site_onions.items(): + yield "- %-34s: %s
" % (site_address, onion) + # Sites yield "

Sites:" yield "" yield "" - for site in self.server.sites.values(): + for site in sorted(self.server.sites.values(), lambda a, b: cmp(a.address,b.address)): yield self.formatTableRow([ ( - """%s""", - (site.address, site.address) + """%s""", + (site.settings["serving"], site.address, site.address) ), ("%s", [peer.connection.id for peer in site.peers.values() if peer.connection and peer.connection.connected]), ("%s/%s/%s", ( @@ -133,10 +139,10 @@ class UiRequestPlugin(object): ("%.0fkB", site.settings.get("bytes_sent", 0) / 1024), ("%.0fkB", site.settings.get("bytes_recv", 0) / 1024), ]) - yield "" yield "
address connected peers content.json out in
" @@ -155,7 +161,6 @@ class UiRequestPlugin(object): # Object types - obj_count = {} for obj in gc.get_objects(): obj_type = str(type(obj)) @@ -325,9 +330,12 @@ class UiRequestPlugin(object): ] if not refs: continue - yield "%.1fkb %s... " % ( - float(sys.getsizeof(obj)) / 1024, cgi.escape(str(obj)), cgi.escape(str(obj)[0:100].ljust(100)) - ) + try: + yield "%.1fkb %s... " % ( + float(sys.getsizeof(obj)) / 1024, cgi.escape(str(obj)), cgi.escape(str(obj)[0:100].ljust(100)) + ) + except: + continue for ref in refs: yield " [" if "object at" in str(ref) or len(str(ref)) > 100: @@ -445,12 +453,21 @@ class UiRequestPlugin(object): from cStringIO import StringIO data = StringIO("Hello" * 1024 * 1024) # 5m - with benchmark("sha512 x 100 000", 1): + with benchmark("sha256 5M x 10", 0.6): for i in range(10): - for y in range(10000): - hash = CryptHash.sha512sum(data) + data.seek(0) + hash = CryptHash.sha256sum(data) yield "." - valid = "cf83e1357eefb8bdf1542850d66d8007d620e4050b5715dc83f4a921d36ce9ce" + valid = "8cd629d9d6aff6590da8b80782a5046d2673d5917b99d5603c3dcb4005c45ffa" + assert hash == valid, "%s != %s" % (hash, valid) + + data = StringIO("Hello" * 1024 * 1024) # 5m + with benchmark("sha512 5M x 10", 0.6): + for i in range(10): + data.seek(0) + hash = CryptHash.sha512sum(data) + yield "." + valid = "9ca7e855d430964d5b55b114e95c6bbb114a6d478f6485df93044d87b108904d" assert hash == valid, "%s != %s" % (hash, valid) with benchmark("os.urandom(256) x 100 000", 0.65): diff --git a/plugins/disabled-Bootstrapper/BootstrapperDb.py b/plugins/disabled-Bootstrapper/BootstrapperDb.py new file mode 100644 index 00000000..b07ac471 --- /dev/null +++ b/plugins/disabled-Bootstrapper/BootstrapperDb.py @@ -0,0 +1,157 @@ +import time +import re + +import gevent + +from Config import config +from Db import Db +from util import helper + + +class BootstrapperDb(Db): + def __init__(self): + self.version = 6 + self.hash_ids = {} # hash -> id cache + super(BootstrapperDb, self).__init__({"db_name": "Bootstrapper"}, "%s/bootstrapper.db" % config.data_dir) + self.foreign_keys = True + self.checkTables() + self.updateHashCache() + gevent.spawn(self.cleanup) + + def cleanup(self): + while 1: + self.execute("DELETE FROM peer WHERE date_announced < DATETIME('now', '-40 minute')") + time.sleep(4*60) + + def updateHashCache(self): + res = self.execute("SELECT * FROM hash") + self.hash_ids = {str(row["hash"]): row["hash_id"] for row in res} + self.log.debug("Loaded %s hash_ids" % len(self.hash_ids)) + + def checkTables(self): + version = int(self.execute("PRAGMA user_version").fetchone()[0]) + self.log.debug("Db version: %s, needed: %s" % (version, self.version)) + if version < self.version: + self.createTables() + else: + self.execute("VACUUM") + + def createTables(self): + # Delete all tables + self.execute("PRAGMA writable_schema = 1") + self.execute("DELETE FROM sqlite_master WHERE type IN ('table', 'index', 'trigger')") + self.execute("PRAGMA writable_schema = 0") + self.execute("VACUUM") + self.execute("PRAGMA INTEGRITY_CHECK") + # Create new tables + self.execute(""" + CREATE TABLE peer ( + peer_id INTEGER PRIMARY KEY ASC AUTOINCREMENT NOT NULL UNIQUE, + port INTEGER NOT NULL, + ip4 TEXT, + onion TEXT, + date_added DATETIME DEFAULT (CURRENT_TIMESTAMP), + date_announced DATETIME DEFAULT (CURRENT_TIMESTAMP) + ); + """) + + self.execute(""" + CREATE TABLE peer_to_hash ( + peer_to_hash_id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE NOT NULL, + peer_id INTEGER REFERENCES peer (peer_id) ON DELETE CASCADE, + hash_id INTEGER REFERENCES hash (hash_id) + ); + """) + self.execute("CREATE INDEX peer_id ON peer_to_hash (peer_id);") + self.execute("CREATE INDEX hash_id ON peer_to_hash (hash_id);") + + self.execute(""" + CREATE TABLE hash ( + hash_id INTEGER PRIMARY KEY AUTOINCREMENT UNIQUE NOT NULL, + hash BLOB UNIQUE NOT NULL, + date_added DATETIME DEFAULT (CURRENT_TIMESTAMP) + ); + """) + self.execute("PRAGMA user_version = %s" % self.version) + + def getHashId(self, hash): + if hash not in self.hash_ids: + self.log.debug("New hash: %s" % repr(hash)) + self.execute("INSERT OR IGNORE INTO hash ?", {"hash": buffer(hash)}) + self.hash_ids[hash] = self.cur.cursor.lastrowid + return self.hash_ids[hash] + + def peerAnnounce(self, ip4=None, onion=None, port=None, hashes=[], onion_signed=False, delete_missing_hashes=False): + hashes_ids_announced = [] + for hash in hashes: + hashes_ids_announced.append(self.getHashId(hash)) + + if not ip4 and not onion: + return 0 + + # Check user + if onion: + res = self.execute("SELECT * FROM peer WHERE ? LIMIT 1", {"onion": onion}) + else: + res = self.execute("SELECT * FROM peer WHERE ? LIMIT 1", {"ip4": ip4, "port": port}) + + user_row = res.fetchone() + if user_row: + peer_id = user_row["peer_id"] + self.execute("UPDATE peer SET date_announced = DATETIME('now') WHERE ?", {"peer_id": peer_id}) + else: + self.log.debug("New peer: %s %s signed: %s" % (ip4, onion, onion_signed)) + if onion and not onion_signed: + return len(hashes) + self.execute("INSERT INTO peer ?", {"ip4": ip4, "onion": onion, "port": port}) + peer_id = self.cur.cursor.lastrowid + + # Check user's hashes + res = self.execute("SELECT * FROM peer_to_hash WHERE ?", {"peer_id": peer_id}) + hash_ids_db = [row["hash_id"] for row in res] + if hash_ids_db != hashes_ids_announced: + hash_ids_added = set(hashes_ids_announced) - set(hash_ids_db) + hash_ids_removed = set(hash_ids_db) - set(hashes_ids_announced) + if not onion or onion_signed: + for hash_id in hash_ids_added: + self.execute("INSERT INTO peer_to_hash ?", {"peer_id": peer_id, "hash_id": hash_id}) + if hash_ids_removed and delete_missing_hashes: + self.execute("DELETE FROM peer_to_hash WHERE ?", {"peer_id": peer_id, "hash_id": list(hash_ids_removed)}) + + return len(hash_ids_added) + len(hash_ids_removed) + else: + return 0 + + def peerList(self, hash, ip4=None, onions=[], port=None, limit=30, need_types=["ip4", "onion"]): + hash_peers = {"ip4": [], "onion": []} + if limit == 0: + return hash_peers + hashid = self.getHashId(hash) + + where = "hash_id = :hashid" + if onions: + onions_escaped = ["'%s'" % re.sub("[^a-z0-9,]", "", onion) for onion in onions] + where += " AND (onion NOT IN (%s) OR onion IS NULL)" % ",".join(onions_escaped) + elif ip4: + where += " AND (NOT (ip4 = :ip4 AND port = :port) OR ip4 IS NULL)" + + query = """ + SELECT ip4, port, onion + FROM peer_to_hash + LEFT JOIN peer USING (peer_id) + WHERE %s + LIMIT :limit + """ % where + res = self.execute(query, {"hashid": hashid, "ip4": ip4, "onions": onions, "port": port, "limit": limit}) + + for row in res: + if row["ip4"] and "ip4" in need_types: + hash_peers["ip4"].append( + helper.packAddress(row["ip4"], row["port"]) + ) + if row["onion"] and "onion" in need_types: + hash_peers["onion"].append( + helper.packOnionAddress(row["onion"], row["port"]) + ) + + return hash_peers diff --git a/plugins/disabled-Bootstrapper/BootstrapperPlugin.py b/plugins/disabled-Bootstrapper/BootstrapperPlugin.py new file mode 100644 index 00000000..7d4360c2 --- /dev/null +++ b/plugins/disabled-Bootstrapper/BootstrapperPlugin.py @@ -0,0 +1,105 @@ +import time + +from Plugin import PluginManager +from BootstrapperDb import BootstrapperDb +from Crypt import CryptRsa + +if "db" not in locals().keys(): # Share durin reloads + db = BootstrapperDb() + + +@PluginManager.registerTo("FileRequest") +class FileRequestPlugin(object): + def actionAnnounce(self, params): + hashes = params["hashes"] + + if "onion_signs" in params and len(params["onion_signs"]) == len(hashes): + # Check if all sign is correct + if time.time() - float(params["onion_sign_this"]) < 3*60: # Peer has 3 minute to sign the message + onions_signed = [] + # Check onion signs + for onion_publickey, onion_sign in params["onion_signs"].items(): + if CryptRsa.verify(params["onion_sign_this"], onion_publickey, onion_sign): + onions_signed.append(CryptRsa.publickeyToOnion(onion_publickey)) + else: + break + # Check if the same onion addresses signed as the announced onces + if sorted(onions_signed) == sorted(params["onions"]): + all_onions_signed = True + else: + all_onions_signed = False + else: + # Onion sign this out of 3 minute + all_onions_signed = False + else: + # Incorrect signs number + all_onions_signed = False + + if "ip4" in params["add"] and self.connection.ip != "127.0.0.1" and not self.connection.ip.endswith(".onion"): + ip4 = self.connection.ip + else: + ip4 = None + + # Separatley add onions to sites or at once if no onions present + hashes_changed = 0 + i = 0 + for onion in params.get("onions", []): + hashes_changed += db.peerAnnounce( + onion=onion, + port=params["port"], + hashes=[hashes[i]], + onion_signed=all_onions_signed + ) + i += 1 + # Announce all sites if ip4 defined + if ip4: + hashes_changed += db.peerAnnounce( + ip4=ip4, + port=params["port"], + hashes=hashes, + delete_missing_hashes=params.get("delete") + ) + + # Query sites + back = {} + peers = [] + if params.get("onions") and not all_onions_signed and hashes_changed: + back["onion_sign_this"] = "%.0f" % time.time() # Send back nonce for signing + + for hash in hashes: + hash_peers = db.peerList( + hash, + ip4=self.connection.ip, onions=params.get("onions"), port=params["port"], + limit=min(30, params["need_num"]), need_types=params["need_types"] + ) + peers.append(hash_peers) + + back["peers"] = peers + self.response(back) + + +@PluginManager.registerTo("UiRequest") +class UiRequestPlugin(object): + def actionStatsBootstrapper(self): + self.sendHeader() + + # Style + yield """ + + """ + + hash_rows = db.execute("SELECT * FROM hash").fetchall() + for hash_row in hash_rows: + peer_rows = db.execute( + "SELECT * FROM peer LEFT JOIN peer_to_hash USING (peer_id) WHERE hash_id = :hash_id", + {"hash_id": hash_row["hash_id"]} + ).fetchall() + + yield "
%s (added: %s, peers: %s)
" % ( + str(hash_row["hash"]).encode("hex"), hash_row["date_added"], len(peer_rows) + ) + for peer_row in peer_rows: + yield " - {ip4: <30} {onion: <30} added: {date_added}, announced: {date_announced}
".format(**dict(peer_row)) diff --git a/plugins/disabled-Bootstrapper/Test/TestBootstrapper.py b/plugins/disabled-Bootstrapper/Test/TestBootstrapper.py new file mode 100644 index 00000000..e49bfd3e --- /dev/null +++ b/plugins/disabled-Bootstrapper/Test/TestBootstrapper.py @@ -0,0 +1,179 @@ +import hashlib +import os + +import pytest + +from Bootstrapper import BootstrapperPlugin +from Bootstrapper.BootstrapperDb import BootstrapperDb +from Peer import Peer +from Crypt import CryptRsa +from util import helper + + +@pytest.fixture() +def bootstrapper_db(request): + BootstrapperPlugin.db.close() + BootstrapperPlugin.db = BootstrapperDb() + BootstrapperPlugin.db.createTables() # Reset db + BootstrapperPlugin.db.cur.logging = True + + def cleanup(): + BootstrapperPlugin.db.close() + os.unlink(BootstrapperPlugin.db.db_path) + + request.addfinalizer(cleanup) + return BootstrapperPlugin.db + + +@pytest.mark.usefixtures("resetSettings") +class TestBootstrapper: + def testIp4(self, file_server, bootstrapper_db): + peer = Peer("127.0.0.1", 1544, connection_server=file_server) + hash1 = hashlib.sha256("site1").digest() + hash2 = hashlib.sha256("site2").digest() + hash3 = hashlib.sha256("site3").digest() + + # Verify empty result + res = peer.request("announce", { + "hashes": [hash1, hash2], + "port": 15441, "need_types": ["ip4"], "need_num": 10, "add": ["ip4"] + }) + + assert len(res["peers"][0]["ip4"]) == 0 # Empty result + + # Verify added peer on previous request + bootstrapper_db.peerAnnounce(ip4="1.2.3.4", port=15441, hashes=[hash1, hash2], delete_missing_hashes=True) + + res = peer.request("announce", { + "hashes": [hash1, hash2], + "port": 15441, "need_types": ["ip4"], "need_num": 10, "add": ["ip4"] + }) + assert len(res["peers"][0]["ip4"]) == 1 + assert len(res["peers"][1]["ip4"]) == 1 + + # hash2 deleted from 1.2.3.4 + bootstrapper_db.peerAnnounce(ip4="1.2.3.4", port=15441, hashes=[hash1], delete_missing_hashes=True) + res = peer.request("announce", { + "hashes": [hash1, hash2], + "port": 15441, "need_types": ["ip4"], "need_num": 10, "add": ["ip4"] + }) + assert len(res["peers"][0]["ip4"]) == 1 + assert len(res["peers"][1]["ip4"]) == 0 + + # Announce 3 hash again + bootstrapper_db.peerAnnounce(ip4="1.2.3.4", port=15441, hashes=[hash1, hash2, hash3], delete_missing_hashes=True) + res = peer.request("announce", { + "hashes": [hash1, hash2, hash3], + "port": 15441, "need_types": ["ip4"], "need_num": 10, "add": ["ip4"] + }) + assert len(res["peers"][0]["ip4"]) == 1 + assert len(res["peers"][1]["ip4"]) == 1 + assert len(res["peers"][2]["ip4"]) == 1 + + # Single hash announce + res = peer.request("announce", { + "hashes": [hash1], "port": 15441, "need_types": ["ip4"], "need_num": 10, "add": ["ip4"] + }) + assert len(res["peers"][0]["ip4"]) == 1 + + # Test DB cleanup + assert bootstrapper_db.execute("SELECT COUNT(*) AS num FROM peer").fetchone()["num"] == 1 # 127.0.0.1 never get added to db + + # Delete peers + bootstrapper_db.execute("DELETE FROM peer WHERE ip4 = '1.2.3.4'") + assert bootstrapper_db.execute("SELECT COUNT(*) AS num FROM peer_to_hash").fetchone()["num"] == 0 + + assert bootstrapper_db.execute("SELECT COUNT(*) AS num FROM hash").fetchone()["num"] == 3 # 3 sites + assert bootstrapper_db.execute("SELECT COUNT(*) AS num FROM peer").fetchone()["num"] == 0 # 0 peer + + def testPassive(self, file_server, bootstrapper_db): + peer = Peer("127.0.0.1", 1544, connection_server=file_server) + hash1 = hashlib.sha256("hash1").digest() + + bootstrapper_db.peerAnnounce(ip4=None, port=15441, hashes=[hash1]) + res = peer.request("announce", { + "hashes": [hash1], "port": 15441, "need_types": ["ip4"], "need_num": 10, "add": [] + }) + + assert len(res["peers"][0]["ip4"]) == 0 # Empty result + + def testAddOnion(self, file_server, site, bootstrapper_db, tor_manager): + onion1 = tor_manager.addOnion() + onion2 = tor_manager.addOnion() + peer = Peer("127.0.0.1", 1544, connection_server=file_server) + hash1 = hashlib.sha256("site1").digest() + hash2 = hashlib.sha256("site2").digest() + + bootstrapper_db.peerAnnounce(ip4="1.2.3.4", port=1234, hashes=[hash1, hash2]) + res = peer.request("announce", { + "onions": [onion1, onion2], + "hashes": [hash1, hash2], "port": 15441, "need_types": ["ip4", "onion"], "need_num": 10, "add": ["onion"] + }) + assert len(res["peers"][0]["ip4"]) == 1 + assert "onion_sign_this" in res + + # Onion address not added yet + site_peers = bootstrapper_db.peerList(ip4="1.2.3.4", port=1234, hash=hash1) + assert len(site_peers["onion"]) == 0 + assert "onion_sign_this" in res + + # Sign the nonces + sign1 = CryptRsa.sign(res["onion_sign_this"], tor_manager.getPrivatekey(onion1)) + sign2 = CryptRsa.sign(res["onion_sign_this"], tor_manager.getPrivatekey(onion2)) + + # Bad sign (different address) + res = peer.request("announce", { + "onions": [onion1], "onion_sign_this": res["onion_sign_this"], + "onion_signs": {tor_manager.getPublickey(onion2): sign2}, + "hashes": [hash1], "port": 15441, "need_types": ["ip4", "onion"], "need_num": 10, "add": ["onion"] + }) + assert "onion_sign_this" in res + site_peers1 = bootstrapper_db.peerList(ip4="1.2.3.4", port=1234, hash=hash1) + assert len(site_peers1["onion"]) == 0 # Not added + + # Bad sign (missing one) + res = peer.request("announce", { + "onions": [onion1, onion2], "onion_sign_this": res["onion_sign_this"], + "onion_signs": {tor_manager.getPublickey(onion1): sign1}, + "hashes": [hash1, hash2], "port": 15441, "need_types": ["ip4", "onion"], "need_num": 10, "add": ["onion"] + }) + assert "onion_sign_this" in res + site_peers1 = bootstrapper_db.peerList(ip4="1.2.3.4", port=1234, hash=hash1) + assert len(site_peers1["onion"]) == 0 # Not added + + # Good sign + res = peer.request("announce", { + "onions": [onion1, onion2], "onion_sign_this": res["onion_sign_this"], + "onion_signs": {tor_manager.getPublickey(onion1): sign1, tor_manager.getPublickey(onion2): sign2}, + "hashes": [hash1, hash2], "port": 15441, "need_types": ["ip4", "onion"], "need_num": 10, "add": ["onion"] + }) + assert "onion_sign_this" not in res + + # Onion addresses added + site_peers1 = bootstrapper_db.peerList(ip4="1.2.3.4", port=1234, hash=hash1) + assert len(site_peers1["onion"]) == 1 + site_peers2 = bootstrapper_db.peerList(ip4="1.2.3.4", port=1234, hash=hash2) + assert len(site_peers2["onion"]) == 1 + + assert site_peers1["onion"][0] != site_peers2["onion"][0] + assert helper.unpackOnionAddress(site_peers1["onion"][0])[0] == onion1+".onion" + assert helper.unpackOnionAddress(site_peers2["onion"][0])[0] == onion2+".onion" + + tor_manager.delOnion(onion1) + tor_manager.delOnion(onion2) + + def testRequestPeers(self, file_server, site, bootstrapper_db, tor_manager): + site.connection_server = file_server + hash = hashlib.sha256(site.address).digest() + + # Request peers from tracker + assert len(site.peers) == 0 + bootstrapper_db.peerAnnounce(ip4="1.2.3.4", port=1234, hashes=[hash]) + site.announceTracker("zero", "127.0.0.1:1544") + assert len(site.peers) == 1 + + # Test onion address store + bootstrapper_db.peerAnnounce(onion="bka4ht2bzxchy44r", port=1234, hashes=[hash], onion_signed=True) + site.announceTracker("zero", "127.0.0.1:1544") + assert len(site.peers) == 2 + assert "bka4ht2bzxchy44r.onion:1234" in site.peers diff --git a/plugins/disabled-Bootstrapper/Test/conftest.py b/plugins/disabled-Bootstrapper/Test/conftest.py new file mode 100644 index 00000000..8c1df5b2 --- /dev/null +++ b/plugins/disabled-Bootstrapper/Test/conftest.py @@ -0,0 +1 @@ +from src.Test.conftest import * \ No newline at end of file diff --git a/plugins/disabled-Bootstrapper/Test/pytest.ini b/plugins/disabled-Bootstrapper/Test/pytest.ini new file mode 100644 index 00000000..d09210d1 --- /dev/null +++ b/plugins/disabled-Bootstrapper/Test/pytest.ini @@ -0,0 +1,5 @@ +[pytest] +python_files = Test*.py +addopts = -rsxX -v --durations=6 +markers = + webtest: mark a test as a webtest. \ No newline at end of file diff --git a/plugins/disabled-Bootstrapper/__init__.py b/plugins/disabled-Bootstrapper/__init__.py new file mode 100644 index 00000000..ca533eac --- /dev/null +++ b/plugins/disabled-Bootstrapper/__init__.py @@ -0,0 +1 @@ +import BootstrapperPlugin \ No newline at end of file diff --git a/src/Config.py b/src/Config.py index f82873b5..86c48182 100644 --- a/src/Config.py +++ b/src/Config.py @@ -7,8 +7,8 @@ import ConfigParser class Config(object): def __init__(self, argv): - self.version = "0.3.4" - self.rev = 668 + self.version = "0.3.5" + self.rev = 830 self.argv = argv self.action = None self.createParser() @@ -30,11 +30,13 @@ class Config(object): # Create command line arguments def createArguments(self): trackers = [ + "zero://boot3rdez4rzn36x.onion:15441", + "zero://boot.zeronet.io#f36ca555bee6ba216b14d10f38c16f7769ff064e0e37d887603548cc2e64191d:15441", "udp://tracker.coppersurfer.tk:6969", "udp://tracker.leechers-paradise.org:6969", "udp://9.rarbg.com:2710", "http://tracker.aletorrenty.pl:2710/announce", - "http://tracker.skyts.net:6969/announce", + "http://explodie.org:6969/announce", "http://torrent.gresille.org/announce" ] # Platform specific @@ -138,6 +140,7 @@ class Config(object): self.parser.add_argument('--disable_encryption', help='Disable connection encryption', action='store_true') self.parser.add_argument('--disable_sslcompression', help='Disable SSL compression to save memory', type='bool', choices=[True, False], default=True) + self.parser.add_argument('--keep_ssl_cert', help='Disable new SSL cert generation on startup', action='store_true') self.parser.add_argument('--use_tempfiles', help='Use temporary files when downloading (experimental)', type='bool', choices=[True, False], default=False) self.parser.add_argument('--stream_downloads', help='Stream download directly to files (experimental)', @@ -148,6 +151,10 @@ class Config(object): self.parser.add_argument('--coffeescript_compiler', help='Coffeescript compiler for developing', default=coffeescript, metavar='executable_path') + self.parser.add_argument('--tor', help='enable: Use only for Tor peers, always: Use Tor for every connection', choices=["disable", "enable", "always"], default='enable') + self.parser.add_argument('--tor_controller', help='Tor controller address', metavar='ip:port', default='127.0.0.1:9051') + self.parser.add_argument('--tor_proxy', help='Tor proxy address', metavar='ip:port', default='127.0.0.1:9050') + self.parser.add_argument('--version', action='version', version='ZeroNet %s r%s' % (self.version, self.rev)) return self.parser diff --git a/src/Connection/Connection.py b/src/Connection/Connection.py index 8c7063be..ef10a632 100644 --- a/src/Connection/Connection.py +++ b/src/Connection/Connection.py @@ -1,5 +1,6 @@ import socket import time +import hashlib import gevent import msgpack @@ -8,20 +9,25 @@ from Config import config from Debug import Debug from util import StreamingMsgpack from Crypt import CryptConnection +from Site import SiteManager class Connection(object): __slots__ = ( - "sock", "sock_wrapped", "ip", "port", "id", "protocol", "type", "server", "unpacker", "req_id", + "sock", "sock_wrapped", "ip", "port", "cert_pin", "site_lock", "id", "protocol", "type", "server", "unpacker", "req_id", "handshake", "crypt", "connected", "event_connected", "closed", "start_time", "last_recv_time", "last_message_time", "last_send_time", "last_sent_time", "incomplete_buff_recv", "bytes_recv", "bytes_sent", "last_ping_delay", "last_req_time", "last_cmd", "name", "updateName", "waiting_requests", "waiting_streams" ) - def __init__(self, server, ip, port, sock=None): + def __init__(self, server, ip, port, sock=None, site_lock=None): self.sock = sock self.ip = ip self.port = port + self.cert_pin = None + if "#" in ip: + self.ip, self.cert_pin = ip.split("#") + self.site_lock = site_lock # Only this site requests allowed (for Tor) self.id = server.last_connection_id server.last_connection_id += 1 self.protocol = "?" @@ -73,17 +79,23 @@ class Connection(object): def connect(self): self.log("Connecting...") self.type = "out" - self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - self.sock.connect((self.ip, int(self.port))) + if self.ip.endswith(".onion"): + if not self.server.tor_manager or not self.server.tor_manager.enabled: + raise Exception("Can't connect to onion addresses, no Tor controller present") + self.sock = self.server.tor_manager.createSocket(self.ip, self.port) + else: + self.sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + self.sock.connect((self.ip, int(self.port))) - # Implicit SSL in the future - # self.sock = CryptConnection.manager.wrapSocket(self.sock, "tls-rsa") - # self.sock.do_handshake() - # self.crypt = "tls-rsa" - # self.sock_wrapped = True + # Implicit SSL + if self.cert_pin: + self.sock = CryptConnection.manager.wrapSocket(self.sock, "tls-rsa", cert_pin=self.cert_pin) + self.sock.do_handshake() + self.crypt = "tls-rsa" + self.sock_wrapped = True # Detect protocol - self.send({"cmd": "handshake", "req_id": 0, "params": self.handshakeInfo()}) + self.send({"cmd": "handshake", "req_id": 0, "params": self.getHandshakeInfo()}) event_connected = self.event_connected gevent.spawn(self.messageLoop) return event_connected.get() # Wait for handshake @@ -92,14 +104,15 @@ class Connection(object): def handleIncomingConnection(self, sock): self.log("Incoming connection...") self.type = "in" - try: - if sock.recv(1, gevent.socket.MSG_PEEK) == "\x16": - self.log("Crypt in connection using implicit SSL") - self.sock = CryptConnection.manager.wrapSocket(self.sock, "tls-rsa", True) - self.sock_wrapped = True - self.crypt = "tls-rsa" - except Exception, err: - self.log("Socket peek error: %s" % Debug.formatException(err)) + if self.ip != "127.0.0.1": # Clearnet: Check implicit SSL + try: + if sock.recv(1, gevent.socket.MSG_PEEK) == "\x16": + self.log("Crypt in connection using implicit SSL") + self.sock = CryptConnection.manager.wrapSocket(self.sock, "tls-rsa", True) + self.sock_wrapped = True + self.crypt = "tls-rsa" + except Exception, err: + self.log("Socket peek error: %s" % Debug.formatException(err)) self.messageLoop() # Message loop for connection @@ -142,28 +155,60 @@ class Connection(object): self.close() # MessageLoop ended, close connection # My handshake info - def handshakeInfo(self): - return { + def getHandshakeInfo(self): + # No TLS for onion connections + if self.ip.endswith(".onion"): + crypt_supported = [] + else: + crypt_supported = CryptConnection.manager.crypt_supported + # No peer id for onion connections + if self.ip.endswith(".onion") or self.ip == "127.0.0.1": + peer_id = "" + else: + peer_id = self.server.peer_id + # Setup peer lock from requested onion address + if self.handshake and self.handshake.get("target_ip", "").endswith(".onion"): + target_onion = self.handshake.get("target_ip").replace(".onion", "") # My onion address + onion_sites = {v: k for k, v in self.server.tor_manager.site_onions.items()} # Inverse, Onion: Site address + self.site_lock = onion_sites.get(target_onion) + if not self.site_lock: + self.server.log.error("Unknown target onion address: %s" % target_onion) + self.site_lock = "unknown" + + handshake = { "version": config.version, "protocol": "v2", - "peer_id": self.server.peer_id, + "peer_id": peer_id, "fileserver_port": self.server.port, "port_opened": self.server.port_opened, + "target_ip": self.ip, "rev": config.rev, - "crypt_supported": CryptConnection.manager.crypt_supported, + "crypt_supported": crypt_supported, "crypt": self.crypt } + if self.site_lock: + handshake["onion"] = self.server.tor_manager.getOnion(self.site_lock) + elif self.ip.endswith(".onion"): + handshake["onion"] = self.server.tor_manager.getOnion("global") + + return handshake def setHandshake(self, handshake): self.handshake = handshake - if handshake.get("port_opened", None) is False: # Not connectable + if handshake.get("port_opened", None) is False and not "onion" in handshake: # Not connectable self.port = 0 else: self.port = handshake["fileserver_port"] # Set peer fileserver port + if handshake.get("onion") and not self.ip.endswith(".onion"): # Set incoming connection's onion address + self.ip = handshake["onion"] + ".onion" + self.updateName() + # Check if we can encrypt the connection if handshake.get("crypt_supported") and handshake["peer_id"] not in self.server.broken_ssl_peer_ids: - if handshake.get("crypt"): # Recommended crypt by server + if self.ip.endswith(".onion"): + crypt = None + elif handshake.get("crypt"): # Recommended crypt by server crypt = handshake["crypt"] else: # Select the best supported on both sides crypt = CryptConnection.manager.selectCrypt(handshake["crypt_supported"]) @@ -193,30 +238,21 @@ class Connection(object): self.crypt = message["crypt"] server = (self.type == "in") self.log("Crypt out connection using: %s (server side: %s)..." % (self.crypt, server)) - self.sock = CryptConnection.manager.wrapSocket(self.sock, self.crypt, server) + self.sock = CryptConnection.manager.wrapSocket(self.sock, self.crypt, server, cert_pin=self.cert_pin) self.sock.do_handshake() + self.sock_wrapped = True + + if not self.sock_wrapped and self.cert_pin: + self.log("Crypt connection error: Socket not encrypted, but certificate pin present") + self.close() + return + self.setHandshake(message) else: self.log("Unknown response: %s" % message) elif message.get("cmd"): # Handhsake request if message["cmd"] == "handshake": - if config.debug_socket: - self.log("Handshake request: %s" % message) - self.setHandshake(message["params"]) - data = self.handshakeInfo() - data["cmd"] = "response" - data["to"] = message["req_id"] - self.send(data) # Send response to handshake - # Sent crypt request to client - if self.crypt and not self.sock_wrapped: - server = (self.type == "in") - self.log("Crypt in connection using: %s (server side: %s)..." % (self.crypt, server)) - try: - self.sock = CryptConnection.manager.wrapSocket(self.sock, self.crypt, server) - self.sock_wrapped = True - except Exception, err: - self.log("Crypt connection error: %s, adding peerid %s as broken ssl." % (err, message["params"]["peer_id"])) - self.server.broken_ssl_peer_ids[message["params"]["peer_id"]] = True + self.handleHandshake(message) else: self.server.handleRequest(self, message) else: # Old style response, no req_id definied @@ -226,6 +262,30 @@ class Connection(object): self.waiting_requests[last_req_id].set(message) del self.waiting_requests[last_req_id] # Remove from waiting request + # Incoming handshake set request + def handleHandshake(self, message): + if config.debug_socket: + self.log("Handshake request: %s" % message) + self.setHandshake(message["params"]) + data = self.getHandshakeInfo() + data["cmd"] = "response" + data["to"] = message["req_id"] + self.send(data) # Send response to handshake + # Sent crypt request to client + if self.crypt and not self.sock_wrapped: + server = (self.type == "in") + self.log("Crypt in connection using: %s (server side: %s)..." % (self.crypt, server)) + try: + self.sock = CryptConnection.manager.wrapSocket(self.sock, self.crypt, server, cert_pin=self.cert_pin) + self.sock_wrapped = True + except Exception, err: + self.log("Crypt connection error: %s, adding peerid %s as broken ssl." % (err, message["params"]["peer_id"])) + self.server.broken_ssl_peer_ids[message["params"]["peer_id"]] = True + + if not self.sock_wrapped and self.cert_pin: + self.log("Crypt connection error: Socket not encrypted, but certificate pin present") + self.close() + # Stream socket directly to a file def handleStream(self, message): if config.debug_socket: diff --git a/src/Connection/ConnectionServer.py b/src/Connection/ConnectionServer.py index 72c53c83..8e19761b 100644 --- a/src/Connection/ConnectionServer.py +++ b/src/Connection/ConnectionServer.py @@ -1,6 +1,4 @@ import logging -import random -import string import time import sys @@ -14,6 +12,7 @@ from Connection import Connection from Config import config from Crypt import CryptConnection from Crypt import CryptHash +from Tor import TorManager class ConnectionServer: @@ -24,7 +23,13 @@ class ConnectionServer: self.log = logging.getLogger("ConnServer") self.port_opened = None + if config.tor != "disabled": + self.tor_manager = TorManager(self.ip, self.port) + else: + self.tor_manager = None + self.connections = [] # Connections + self.whitelist = ("127.0.0.1",) # No flood protection on this ips self.ip_incoming = {} # Incoming connections from ip in the last minute to avoid connection flood self.broken_ssl_peer_ids = {} # Peerids of broken ssl connections self.ips = {} # Connection by ip @@ -41,7 +46,7 @@ class ConnectionServer: # Check msgpack version if msgpack.version[0] == 0 and msgpack.version[1] < 4: self.log.error( - "Error: Unsupported msgpack version: %s (<0.4.0), please run `sudo pip install msgpack-python --upgrade`" % + "Error: Unsupported msgpack version: %s (<0.4.0), please run `sudo apt-get install python-pip; sudo pip install msgpack-python --upgrade`" % str(msgpack.version) ) sys.exit(0) @@ -74,7 +79,7 @@ class ConnectionServer: ip, port = addr # Connection flood protection - if ip in self.ip_incoming: + if ip in self.ip_incoming and ip not in self.whitelist: self.ip_incoming[ip] += 1 if self.ip_incoming[ip] > 3: # Allow 3 in 1 minute from same ip self.log.debug("Connection flood detected from %s" % ip) @@ -89,10 +94,15 @@ class ConnectionServer: self.ips[ip] = connection connection.handleIncomingConnection(sock) - def getConnection(self, ip=None, port=None, peer_id=None, create=True): + def getConnection(self, ip=None, port=None, peer_id=None, create=True, site=None): + if ip.endswith(".onion") and self.tor_manager.start_onions and site: # Site-unique connection for Tor + key = ip + site.address + else: + key = ip + # Find connection by ip - if ip in self.ips: - connection = self.ips[ip] + if key in self.ips: + connection = self.ips[key] if not peer_id or connection.handshake.get("peer_id") == peer_id: # Filter by peer_id if not connection.connected and create: succ = connection.event_connected.get() # Wait for connection @@ -105,6 +115,9 @@ class ConnectionServer: if connection.ip == ip: if peer_id and connection.handshake.get("peer_id") != peer_id: # Does not match continue + if ip.endswith(".onion") and self.tor_manager.start_onions and connection.site_lock != site.address: + # For different site + continue if not connection.connected and create: succ = connection.event_connected.get() # Wait for connection if not succ: @@ -116,8 +129,11 @@ class ConnectionServer: if port == 0: raise Exception("This peer is not connectable") try: - connection = Connection(self, ip, port) - self.ips[ip] = connection + if ip.endswith(".onion") and self.tor_manager.start_onions and site: # Lock connection to site + connection = Connection(self, ip, port, site_lock=site.address) + else: + connection = Connection(self, ip, port) + self.ips[key] = connection self.connections.append(connection) succ = connection.connect() if not succ: @@ -134,14 +150,22 @@ class ConnectionServer: def removeConnection(self, connection): self.log.debug("Removing %s..." % connection) - if self.ips.get(connection.ip) == connection: # Delete if same as in registry + # Delete if same as in registry + if self.ips.get(connection.ip) == connection: del self.ips[connection.ip] + # Site locked connection + if connection.site_lock and self.ips.get(connection.ip + connection.site_lock) == connection: + del self.ips[connection.ip + connection.site_lock] + # Cert pinned connection + if connection.cert_pin and self.ips.get(connection.ip + "#" + connection.cert_pin) == connection: + del self.ips[connection.ip + "#" + connection.cert_pin] + if connection in self.connections: self.connections.remove(connection) def checkConnections(self): while self.running: - time.sleep(60) # Sleep 1 min + time.sleep(60) # Check every minute self.ip_incoming = {} # Reset connected ips counter self.broken_ssl_peer_ids = {} # Reset broken ssl peerids count for connection in self.connections[:]: # Make a copy @@ -151,7 +175,10 @@ class ConnectionServer: # Delete the unpacker if not needed del connection.unpacker connection.unpacker = None - connection.log("Unpacker deleted") + + elif connection.last_cmd == "announce" and idle > 20: # Bootstrapper connection close after 20 sec + connection.log("[Cleanup] Tracker connection: %s" % idle) + connection.close() if idle > 60 * 60: # Wake up after 1h diff --git a/src/Crypt/CryptConnection.py b/src/Crypt/CryptConnection.py index fb2c0920..61d96acc 100644 --- a/src/Crypt/CryptConnection.py +++ b/src/Crypt/CryptConnection.py @@ -2,6 +2,7 @@ import sys import logging import os import ssl +import hashlib from Config import config from util import SslPatch @@ -29,20 +30,26 @@ class CryptConnectionManager: # Wrap socket for crypt # Return: wrapped socket - def wrapSocket(self, sock, crypt, server=False): + def wrapSocket(self, sock, crypt, server=False, cert_pin=None): if crypt == "tls-rsa": ciphers = "ECDHE-ECDSA-AES128-GCM-SHA256:ECDHE-RSA-AES128-GCM-SHA256:AES128-GCM-SHA256:AES128-SHA256:HIGH:" ciphers += "!aNULL:!eNULL:!EXPORT:!DSS:!DES:!RC4:!3DES:!MD5:!PSK" if server: - return ssl.wrap_socket( + sock_wrapped = ssl.wrap_socket( sock, server_side=server, keyfile='%s/key-rsa.pem' % config.data_dir, certfile='%s/cert-rsa.pem' % config.data_dir, ciphers=ciphers) else: - return ssl.wrap_socket(sock, ciphers=ciphers) + sock_wrapped = ssl.wrap_socket(sock, ciphers=ciphers) + if cert_pin: + cert_hash = hashlib.sha256(sock_wrapped.getpeercert(True)).hexdigest() + assert cert_hash == cert_pin, "Socket certificate does not match (%s != %s)" % (cert_hash, cert_pin) + return sock_wrapped else: return sock def removeCerts(self): + if config.keep_ssl_cert: + return False for file_name in ["cert-rsa.pem", "key-rsa.pem"]: file_path = "%s/%s" % (config.data_dir, file_name) if os.path.isfile(file_path): @@ -59,11 +66,10 @@ class CryptConnectionManager: # Try to create RSA server cert + sign for connection encryption # Return: True on success def createSslRsaCert(self): - import subprocess - if os.path.isfile("%s/cert-rsa.pem" % config.data_dir) and os.path.isfile("%s/key-rsa.pem" % config.data_dir): return True # Files already exits + import subprocess proc = subprocess.Popen( "%s req -x509 -newkey rsa:2048 -sha256 -batch -keyout %s -out %s -nodes -config %s" % helper.shellquote( self.openssl_bin, diff --git a/src/Crypt/CryptHash.py b/src/Crypt/CryptHash.py index e25c06c6..fb0c2dab 100644 --- a/src/Crypt/CryptHash.py +++ b/src/Crypt/CryptHash.py @@ -21,6 +21,15 @@ def sha512sum(file, blocksize=65536): return hash.hexdigest()[0:64] # Truncate to 256bits is good enough +def sha256sum(file, blocksize=65536): + if hasattr(file, "endswith"): # Its a string open it + file = open(file, "rb") + hash = hashlib.sha256() + for block in iter(lambda: file.read(blocksize), ""): + hash.update(block) + return hash.hexdigest() + + def random(length=64, encoding="hex"): if encoding == "base64": # Characters: A-Za-z0-9 hash = hashlib.sha512(os.urandom(256)).digest() diff --git a/src/Crypt/CryptRsa.py b/src/Crypt/CryptRsa.py new file mode 100644 index 00000000..694ef34f --- /dev/null +++ b/src/Crypt/CryptRsa.py @@ -0,0 +1,38 @@ +import base64 +import hashlib + +def sign(data, privatekey): + from lib import rsa + from lib.rsa import pkcs1 + + if "BEGIN RSA PRIVATE KEY" not in privatekey: + privatekey = "-----BEGIN RSA PRIVATE KEY-----\n%s\n-----END RSA PRIVATE KEY-----" % privatekey + + priv = rsa.PrivateKey.load_pkcs1(privatekey) + sign = rsa.pkcs1.sign(data, priv, 'SHA-256') + return sign + +def verify(data, publickey, sign): + from lib import rsa + from lib.rsa import pkcs1 + + pub = rsa.PublicKey.load_pkcs1(publickey, format="DER") + try: + valid = rsa.pkcs1.verify(data, sign, pub) + except pkcs1.VerificationError: + valid = False + return valid + +def privatekeyToPublickey(privatekey): + from lib import rsa + from lib.rsa import pkcs1 + + if "BEGIN RSA PRIVATE KEY" not in privatekey: + privatekey = "-----BEGIN RSA PRIVATE KEY-----\n%s\n-----END RSA PRIVATE KEY-----" % privatekey + + priv = rsa.PrivateKey.load_pkcs1(privatekey) + pub = rsa.PublicKey(priv.n, priv.e) + return pub.save_pkcs1("DER") + +def publickeyToOnion(publickey): + return base64.b32encode(hashlib.sha1(publickey).digest()[:10]).lower() diff --git a/src/Db/Db.py b/src/Db/Db.py index 591ee206..e6c64bec 100644 --- a/src/Db/Db.py +++ b/src/Db/Db.py @@ -22,7 +22,7 @@ def dbCleanup(): gevent.spawn(dbCleanup) -class Db: +class Db(object): def __init__(self, schema, db_path): self.db_path = db_path @@ -34,6 +34,7 @@ class Db: self.log = logging.getLogger("Db:%s" % schema["db_name"]) self.table_names = None self.collect_stats = False + self.foreign_keys = False self.query_stats = {} self.db_keyvalues = {} self.last_query_time = time.time() @@ -59,6 +60,9 @@ class Db: self.cur.execute("PRAGMA journal_mode = WAL") self.cur.execute("PRAGMA journal_mode = MEMORY") self.cur.execute("PRAGMA synchronous = OFF") + if self.foreign_keys: + self.execute("PRAGMA foreign_keys = ON") + # Execute query using dbcursor def execute(self, query, params=None): diff --git a/src/Db/DbCursor.py b/src/Db/DbCursor.py index a34f9157..f3a1c532 100644 --- a/src/Db/DbCursor.py +++ b/src/Db/DbCursor.py @@ -13,17 +13,23 @@ class DbCursor: self.logging = False def execute(self, query, params=None): - if isinstance(params, dict): # Make easier select and insert by allowing dict params - if query.startswith("SELECT") or query.startswith("DELETE"): + if isinstance(params, dict) and "?" in query: # Make easier select and insert by allowing dict params + if query.startswith("SELECT") or query.startswith("DELETE") or query.startswith("UPDATE"): # Convert param dict to SELECT * FROM table WHERE key = ? AND key2 = ? format query_wheres = [] values = [] for key, value in params.items(): if type(value) is list: - query_wheres.append(key+" IN ("+",".join(["?"]*len(value))+")") + if key.startswith("not__"): + query_wheres.append(key.replace("not__", "")+" NOT IN ("+",".join(["?"]*len(value))+")") + else: + query_wheres.append(key+" IN ("+",".join(["?"]*len(value))+")") values += value else: - query_wheres.append(key+" = ?") + if key.startswith("not__"): + query_wheres.append(key.replace("not__", "")+" != ?") + else: + query_wheres.append(key+" = ?") values.append(value) wheres = " AND ".join(query_wheres) query = query.replace("?", wheres) @@ -41,7 +47,7 @@ class DbCursor: if params: # Query has parameters res = self.cursor.execute(query, params) if self.logging: - self.db.log.debug((query.replace("?", "%s") % params) + " (Done in %.4f)" % (time.time() - s)) + self.db.log.debug(query + " " + str(params) + " (Done in %.4f)" % (time.time() - s)) else: res = self.cursor.execute(query) if self.logging: diff --git a/src/File/FileRequest.py b/src/File/FileRequest.py index bfe39a66..42b7d855 100644 --- a/src/File/FileRequest.py +++ b/src/File/FileRequest.py @@ -11,11 +11,13 @@ from Config import config from util import RateLimit from util import StreamingMsgpack from util import helper +from Plugin import PluginManager FILE_BUFF = 1024 * 512 -# Request from me +# Incoming requests +@PluginManager.acceptPlugins class FileRequest(object): __slots__ = ("server", "connection", "req_id", "sites", "log", "responded") @@ -50,36 +52,25 @@ class FileRequest(object): # Route file requests def route(self, cmd, req_id, params): self.req_id = req_id + # Don't allow other sites than locked + if "site" in params and self.connection.site_lock and self.connection.site_lock not in (params["site"], "global"): + self.response({"error": "Invalid site"}) + self.log.error("Site lock violation: %s != %s" % (self.connection.site_lock != params["site"])) + return False - if cmd == "getFile": - self.actionGetFile(params) - elif cmd == "streamFile": - self.actionStreamFile(params) - elif cmd == "update": + if cmd == "update": event = "%s update %s %s" % (self.connection.id, params["site"], params["inner_path"]) if not RateLimit.isAllowed(event): # There was already an update for this file in the last 10 second self.response({"ok": "File update queued"}) # If called more than once within 10 sec only keep the last update RateLimit.callAsync(event, 10, self.actionUpdate, params) - - elif cmd == "pex": - self.actionPex(params) - elif cmd == "listModified": - self.actionListModified(params) - elif cmd == "getHashfield": - self.actionGetHashfield(params) - elif cmd == "findHashIds": - self.actionFindHashIds(params) - elif cmd == "setHashfield": - self.actionSetHashfield(params) - elif cmd == "siteReload": - self.actionSiteReload(params) - elif cmd == "sitePublish": - self.actionSitePublish(params) - elif cmd == "ping": - self.actionPing() else: - self.actionUnknown(cmd, params) + func_name = "action" + cmd[0].upper() + cmd[1:] + func = getattr(self, func_name, None) + if func: + func(params) + else: + self.actionUnknown(cmd, params) # Update a site file request def actionUpdate(self, params): @@ -117,7 +108,10 @@ class FileRequest(object): self.response({"ok": "Thanks, file %s updated!" % params["inner_path"]}) elif valid is None: # Not changed - peer = site.addPeer(*params["peer"], return_peer=True) # Add or get peer + if params.get("peer"): + peer = site.addPeer(*params["peer"], return_peer=True) # Add or get peer + else: + peer = site.addPeer(self.connection.ip, self.connection.port, return_peer=True) # Add or get peer if peer: self.log.debug( "Same version, adding new peer for locked files: %s, tasks: %s" % @@ -148,7 +142,7 @@ class FileRequest(object): file.seek(params["location"]) file.read_bytes = FILE_BUFF file_size = os.fstat(file.fileno()).st_size - assert params["location"] < file_size + assert params["location"] <= file_size, "Bad file location" back = { "body": file, @@ -190,7 +184,7 @@ class FileRequest(object): file.seek(params["location"]) file_size = os.fstat(file.fileno()).st_size stream_bytes = min(FILE_BUFF, file_size - params["location"]) - assert stream_bytes >= 0 + assert stream_bytes >= 0, "Stream bytes out of range" back = { "size": file_size, @@ -236,18 +230,36 @@ class FileRequest(object): connected_peer.connect(self.connection) # Assign current connection to peer # Add sent peers to site - for packed_address in params["peers"]: + for packed_address in params.get("peers", []): address = helper.unpackAddress(packed_address) got_peer_keys.append("%s:%s" % address) if site.addPeer(*address): added += 1 + # Add sent peers to site + for packed_address in params.get("peers_onion", []): + address = helper.unpackOnionAddress(packed_address) + got_peer_keys.append("%s:%s" % address) + if site.addPeer(*address): + added += 1 + # Send back peers that is not in the sent list and connectable (not port 0) - packed_peers = [peer.packMyAddress() for peer in site.getConnectablePeers(params["need"], got_peer_keys)] + packed_peers = helper.packPeers(site.getConnectablePeers(params["need"], got_peer_keys)) + if added: site.worker_manager.onPeers() - self.log.debug("Added %s peers to %s using pex, sending back %s" % (added, site, len(packed_peers))) - self.response({"peers": packed_peers}) + self.log.debug( + "Added %s peers to %s using pex, sending back %s" % + (added, site, len(packed_peers["ip4"]) + len(packed_peers["onion"])) + ) + + back = {} + if packed_peers["ip4"]: + back["peers"] = packed_peers["ip4"] + if packed_peers["onion"]: + back["peers_onion"] = packed_peers["onion"] + + self.response(back) # Get modified content.json files since def actionListModified(self, params): @@ -316,7 +328,7 @@ class FileRequest(object): self.response({"error": "Unknown site"}) return False - peer = site.addPeer(self.connection.ip, self.connection.port, return_peer=True) # Add or get peer + peer = site.addPeer(self.connection.ip, self.connection.port, return_peer=True, connection=self.connection) # Add or get peer if not peer.connection: peer.connect(self.connection) peer.hashfield.replaceFromString(params["hashfield_raw"]) @@ -343,7 +355,7 @@ class FileRequest(object): self.response({"ok": "Successfuly published to %s peers" % num}) # Send a simple Pong! answer - def actionPing(self): + def actionPing(self, params): self.response("Pong!") # Unknown command diff --git a/src/File/FileServer.py b/src/File/FileServer.py index 27e681a1..a232930d 100644 --- a/src/File/FileServer.py +++ b/src/File/FileServer.py @@ -49,9 +49,13 @@ class FileServer(ConnectionServer): if self.port_opened: return True # Port already opened if check: # Check first if its already opened - if self.testOpenport(port)["result"] is True: + time.sleep(1) # Wait for port open + if self.testOpenport(port, use_alternative=False)["result"] is True: return True # Port already opened + if config.tor == "always": # Port opening won't work in Tor mode + return False + self.log.info("Trying to open port using UpnpPunch...") try: upnp_punch = UpnpPunch.open_port(self.port, 'ZeroNet') @@ -67,15 +71,14 @@ class FileServer(ConnectionServer): return False # Test if the port is open - def testOpenport(self, port=None): - time.sleep(1) # Wait for port open + def testOpenport(self, port=None, use_alternative=True): if not port: port = self.port back = self.testOpenportPortchecker(port) - if back["result"] is True: # Successful port check - return back - else: # Alternative port checker + if back["result"] is not True and use_alternative: # If no success try alternative checker return self.testOpenportCanyouseeme(port) + else: + return back def testOpenportPortchecker(self, port=None): self.log.info("Checking port %s using portchecker.co..." % port) @@ -151,16 +154,24 @@ class FileServer(ConnectionServer): # Check site file integrity def checkSite(self, site): if site.settings["serving"]: - site.announce() # Announce site to tracker + site.announce(mode="startup") # Announce site to tracker site.update() # Update site's content.json and download changed files + site.sendMyHashfield() + site.updateHashfield() if self.port_opened is False: # In passive mode keep 5 active peer connection to get the updates site.needConnections() # Check sites integrity def checkSites(self): if self.port_opened is None: # Test and open port if not tested yet + if len(self.sites) <= 2: # Faster announce on first startup + for address, site in self.sites.items(): + gevent.spawn(self.checkSite, site) self.openport() + if not self.port_opened: + self.tor_manager.startOnions() + self.log.debug("Checking sites integrity..") for address, site in self.sites.items(): # Check sites integrity gevent.spawn(self.checkSite, site) # Check in new thread @@ -170,36 +181,30 @@ class FileServer(ConnectionServer): # Announce sites every 20 min def announceSites(self): import gc - first_announce = True # First start while 1: - # Sites healthcare every 20 min + # Sites health care every 20 min if config.trackers_file: config.loadTrackersFile() for address, site in self.sites.items(): - if site.settings["serving"]: - if first_announce: # Announce to all trackers on startup - site.announce() - else: # If not first run only use PEX - site.announcePex() + if not site.settings["serving"]: + continue + if site.peers: + site.announcePex() - # Retry failed files - if site.bad_files: - site.retryBadFiles() + # Retry failed files + if site.bad_files: + site.retryBadFiles() - site.cleanupPeers() + site.cleanupPeers() - # In passive mode keep 5 active peer connection to get the updates - if self.port_opened is False: - site.needConnections() - - if first_announce: # Send my optional files to peers - site.sendMyHashfield() - site.updateHashfield() + # In passive mode keep 5 active peer connection to get the updates + if self.port_opened is False: + site.needConnections() time.sleep(2) # Prevent too quick request site = None - gc.collect() # Implicit grabage collection + gc.collect() # Implicit garbage collection # Find new peers for tracker_i in range(len(config.trackers)): @@ -207,13 +212,15 @@ class FileServer(ConnectionServer): if config.trackers_file: config.loadTrackersFile() for address, site in self.sites.items(): - site.announce(num=1, pex=False) + if not site.settings["serving"]: + continue + site.announce(mode="update", pex=False) + if site.settings["own"]: # Check connections more frequently on own sites to speed-up first connections + site.needConnections() site.sendMyHashfield(3) site.updateHashfield(1) time.sleep(2) - first_announce = False - # Detects if computer back from wakeup def wakeupWatcher(self): last_time = time.time() diff --git a/src/Peer/Peer.py b/src/Peer/Peer.py index a543d581..192dfaa7 100644 --- a/src/Peer/Peer.py +++ b/src/Peer/Peer.py @@ -17,17 +17,18 @@ if config.use_tempfiles: # Communicate remote peers class Peer(object): __slots__ = ( - "ip", "port", "site", "key", "connection", "time_found", "time_response", "time_hashfield", "time_added", + "ip", "port", "site", "key", "connection", "connection_server", "time_found", "time_response", "time_hashfield", "time_added", "time_my_hashfield_sent", "last_ping", "hashfield", "connection_error", "hash_failed", "download_bytes", "download_time" ) - def __init__(self, ip, port, site=None): + def __init__(self, ip, port, site=None, connection_server=None): self.ip = ip self.port = port self.site = site self.key = "%s:%s" % (ip, port) self.connection = None + self.connection_server = connection_server self.hashfield = PeerHashfield() # Got optional files hash_id self.time_hashfield = None # Last time peer's hashfiled downloaded self.time_my_hashfield_sent = None # Last time my hashfield sent to peer @@ -61,10 +62,12 @@ class Peer(object): self.connection = None try: - if self.site: - self.connection = self.site.connection_server.getConnection(self.ip, self.port) + if self.connection_server: + self.connection = self.connection_server.getConnection(self.ip, self.port, site=self.site) + elif self.site: + self.connection = self.site.connection_server.getConnection(self.ip, self.port, site=self.site) else: - self.connection = sys.modules["main"].file_server.getConnection(self.ip, self.port) + self.connection = sys.modules["main"].file_server.getConnection(self.ip, self.port, site=self.site) except Exception, err: self.onConnectionError() @@ -77,7 +80,7 @@ class Peer(object): if self.connection and self.connection.connected: # We have connection to peer return self.connection else: # Try to find from other sites connections - self.connection = self.site.connection_server.getConnection(self.ip, self.port, create=False) + self.connection = self.site.connection_server.getConnection(self.ip, self.port, create=False, site=self.site) return self.connection def __str__(self): @@ -87,7 +90,10 @@ class Peer(object): return "<%s>" % self.__str__() def packMyAddress(self): - return helper.packAddress(self.ip, self.port) + if self.ip.endswith(".onion"): + return helper.packOnionAddress(self.ip, self.port) + else: + return helper.packAddress(self.ip, self.port) # Found a peer on tracker def found(self): @@ -155,7 +161,8 @@ class Peer(object): self.download_bytes += res["location"] self.download_time += (time.time() - s) - self.site.settings["bytes_recv"] = self.site.settings.get("bytes_recv", 0) + res["location"] + if self.site: + self.site.settings["bytes_recv"] = self.site.settings.get("bytes_recv", 0) + res["location"] buff.seek(0) return buff @@ -213,18 +220,30 @@ class Peer(object): def pex(self, site=None, need_num=5): if not site: site = self.site # If no site defined request peers for this site - # give him/her 5 connectible peers - packed_peers = [peer.packMyAddress() for peer in self.site.getConnectablePeers(5)] - res = self.request("pex", {"site": site.address, "peers": packed_peers, "need": need_num}) + + # give back 5 connectible peers + packed_peers = helper.packPeers(self.site.getConnectablePeers(5)) + request = {"site": site.address, "peers": packed_peers["ip4"], "need": need_num} + if packed_peers["onion"]: + request["peers_onion"] = packed_peers["onion"] + res = self.request("pex", request) if not res or "error" in res: return False added = 0 + # Ip4 for peer in res.get("peers", []): address = helper.unpackAddress(peer) if site.addPeer(*address): added += 1 + # Onion + for peer in res.get("peers_onion", []): + address = helper.unpackOnionAddress(peer) + if site.addPeer(*address): + added += 1 + if added: self.log("Added peers using pex: %s" % added) + return added # List modified files since the date diff --git a/src/Site/Site.py b/src/Site/Site.py index 4bf5ff3c..77e74f05 100644 --- a/src/Site/Site.py +++ b/src/Site/Site.py @@ -6,7 +6,6 @@ import re import time import random import sys -import binascii import struct import socket import urllib @@ -25,10 +24,12 @@ from Content import ContentManager from SiteStorage import SiteStorage from Crypt import CryptHash from util import helper +from Plugin import PluginManager import SiteManager -class Site: +@PluginManager.acceptPlugins +class Site(object): def __init__(self, address, allow_create=True): self.address = re.sub("[^A-Za-z0-9]", "", address) # Make sure its correct address @@ -297,6 +298,19 @@ class Site: def publisher(self, inner_path, peers, published, limit, event_done=None): file_size = self.storage.getSize(inner_path) body = self.storage.read(inner_path) + tor_manager = self.connection_server.tor_manager + if tor_manager.enabled and tor_manager.start_onions: + my_ip = tor_manager.getOnion(self.address) + if my_ip: + my_ip += ".onion" + my_port = config.fileserver_port + else: + my_ip = config.ip_external + if self.connection_server.port_opened: + my_port = config.fileserver_port + else: + my_port = 0 + while 1: if not peers or len(published) >= limit: if event_done: @@ -318,7 +332,7 @@ class Site: "site": self.address, "inner_path": inner_path, "body": body, - "peer": (config.ip_external, config.fileserver_port) + "peer": (my_ip, my_port) }) if result: break @@ -499,11 +513,9 @@ class Site: # Add or update a peer to site # return_peer: Always return the peer even if it was already present - def addPeer(self, ip, port, return_peer=False): + def addPeer(self, ip, port, return_peer=False, connection=None): if not ip: return False - if (ip, port) in self.peer_blacklist: - return False # Ignore blacklist (eg. myself) key = "%s:%s" % (ip, port) if key in self.peers: # Already has this ip self.peers[key].found() @@ -512,6 +524,8 @@ class Site: else: return False else: # New peer + if (ip, port) in self.peer_blacklist: + return False # Ignore blacklist (eg. myself) peer = Peer(ip, port, self) self.peers[key] = peer return peer @@ -529,13 +543,7 @@ class Site: done = 0 added = 0 for peer in peers: - if peer.connection: # Has connection - if "port_opened" in peer.connection.handshake: # This field added recently, so probably has has peer exchange - res = peer.pex(need_num=need_num) - else: - res = False - else: # No connection - res = peer.pex(need_num=need_num) + res = peer.pex(need_num=need_num) if type(res) == int: # We have result done += 1 added += res @@ -548,33 +556,36 @@ class Site: # Gather peers from tracker # Return: Complete time or False on error - def announceTracker(self, protocol, address, fileserver_port, address_hash, my_peer_id): + def announceTracker(self, tracker_protocol, tracker_address, fileserver_port=0, add_types=[], my_peer_id="", mode="start"): s = time.time() - if protocol == "udp": # Udp tracker + if "ip4" not in add_types: + fileserver_port = 0 + + if tracker_protocol == "udp": # Udp tracker if config.disable_udp: return False # No udp supported - ip, port = address.split(":") + ip, port = tracker_address.split(":") tracker = UdpTrackerClient(ip, int(port)) tracker.peer_port = fileserver_port try: tracker.connect() tracker.poll_once() - tracker.announce(info_hash=address_hash, num_want=50) + tracker.announce(info_hash=hashlib.sha1(self.address).hexdigest(), num_want=50) back = tracker.poll_once() peers = back["response"]["peers"] except Exception, err: return False - else: # Http tracker + elif tracker_protocol == "http": # Http tracker params = { - 'info_hash': binascii.a2b_hex(address_hash), + 'info_hash': hashlib.sha1(self.address).digest(), 'peer_id': my_peer_id, 'port': fileserver_port, 'uploaded': 0, 'downloaded': 0, 'left': 0, 'compact': 1, 'numwant': 30, 'event': 'started' } req = None try: - url = "http://" + address + "?" + urllib.urlencode(params) + url = "http://" + tracker_address + "?" + urllib.urlencode(params) # Load url with gevent.Timeout(30, False): # Make sure of timeout req = urllib2.urlopen(url, timeout=25) @@ -601,6 +612,8 @@ class Site: req.close() req = None return False + else: + peers = [] # Adding peers added = 0 @@ -616,67 +629,75 @@ class Site: return time.time() - s # Add myself and get other peers from tracker - def announce(self, force=False, num=5, pex=True): + def announce(self, force=False, mode="start", pex=True): if time.time() < self.time_announce + 30 and not force: return # No reannouncing within 30 secs self.time_announce = time.time() + trackers = config.trackers + # Filter trackers based on supported networks if config.disable_udp: - trackers = [tracker for tracker in config.trackers if not tracker.startswith("udp://")] - else: - trackers = config.trackers - if num == 1: # Only announce on one tracker, increment the queried tracker id + trackers = [tracker for tracker in trackers if not tracker.startswith("udp://")] + if not self.connection_server.tor_manager.enabled: + trackers = [tracker for tracker in trackers if ".onion" not in tracker] + + if mode == "update" or mode == "more": # Only announce on one tracker, increment the queried tracker id self.last_tracker_id += 1 self.last_tracker_id = self.last_tracker_id % len(trackers) trackers = [trackers[self.last_tracker_id]] # We only going to use this one errors = [] slow = [] - address_hash = hashlib.sha1(self.address).hexdigest() # Site address hash - my_peer_id = sys.modules["main"].file_server.peer_id + add_types = [] + if self.connection_server: + my_peer_id = self.connection_server.peer_id - if sys.modules["main"].file_server.port_opened: - fileserver_port = config.fileserver_port - else: # Port not opened, report port 0 - fileserver_port = 0 + # Type of addresses they can reach me + if self.connection_server.port_opened: + add_types.append("ip4") + if self.connection_server.tor_manager.enabled and self.connection_server.tor_manager.start_onions: + add_types.append("onion") + else: + my_peer_id = "" s = time.time() announced = 0 threads = [] + fileserver_port = config.fileserver_port for tracker in trackers: # Start announce threads - protocol, address = tracker.split("://") - thread = gevent.spawn(self.announceTracker, protocol, address, fileserver_port, address_hash, my_peer_id) + tracker_protocol, tracker_address = tracker.split("://") + thread = gevent.spawn( + self.announceTracker, tracker_protocol, tracker_address, fileserver_port, add_types, my_peer_id, mode + ) threads.append(thread) - thread.address = address - thread.protocol = protocol - if len(threads) > num: # Announce limit - break + thread.tracker_address = tracker_address + thread.tracker_protocol = tracker_protocol gevent.joinall(threads, timeout=10) # Wait for announce finish for thread in threads: if thread.value: if thread.value > 1: - slow.append("%.2fs %s://%s" % (thread.value, thread.protocol, thread.address)) + slow.append("%.2fs %s://%s" % (thread.value, thread.tracker_protocol, thread.tracker_address)) announced += 1 else: if thread.ready(): - errors.append("%s://%s" % (thread.protocol, thread.address)) + errors.append("%s://%s" % (thread.tracker_protocol, thread.tracker_address)) else: # Still running - slow.append("10s+ %s://%s" % (thread.protocol, thread.address)) + slow.append("10s+ %s://%s" % (thread.tracker_protocol, thread.tracker_address)) # Save peers num self.settings["peers"] = len(self.peers) self.saveSettings() - if len(errors) < min(num, len(trackers)): # Less errors than total tracker nums + if len(errors) < len(threads): # Less errors than total tracker nums self.log.debug( - "Announced port %s to %s trackers in %.3fs, errors: %s, slow: %s" % - (fileserver_port, announced, time.time() - s, errors, slow) + "Announced types %s in mode %s to %s trackers in %.3fs, errors: %s, slow: %s" % + (add_types, mode, announced, time.time() - s, errors, slow) ) else: - if num > 1: + if mode != "update": self.log.error("Announce to %s trackers in %.3fs, failed" % (announced, time.time() - s)) if pex: @@ -684,7 +705,10 @@ class Site: # If no connected peer yet then wait for connections gevent.spawn_later(3, self.announcePex, need_num=10) # Spawn 3 secs later else: # Else announce immediately - self.announcePex() + if mode == "more": # Need more peers + self.announcePex(need_num=10) + else: + self.announcePex() # Keep connections to get the updates (required for passive clients) def needConnections(self, num=3): @@ -726,8 +750,7 @@ class Site: if len(found) >= need_num: break # Found requested number of peers - if (not found and not ignore) or (need_num > 5 and need_num < 100 and len(found) < need_num): - # Return not that good peers: Not found any peer and the requester dont have any or cant give enough peer + if need_num > 5 and need_num < 100 and len(found) < need_num: # Return not that good peers found = [peer for peer in peers if not peer.key.endswith(":0") and peer.key not in ignore][0:need_num - len(found)] return found diff --git a/src/Site/SiteManager.py b/src/Site/SiteManager.py index 3890f8b1..cf18f342 100644 --- a/src/Site/SiteManager.py +++ b/src/Site/SiteManager.py @@ -85,4 +85,4 @@ class SiteManager(object): site_manager = SiteManager() # Singletone -peer_blacklist = [] # Dont download from this peers +peer_blacklist = [("127.0.0.1", config.fileserver_port)] # Dont add this peers diff --git a/src/Test/TestConnectionServer.py b/src/Test/TestConnectionServer.py index 94175ffb..c2dba481 100644 --- a/src/Test/TestConnectionServer.py +++ b/src/Test/TestConnectionServer.py @@ -79,6 +79,8 @@ class TestConnection: def testFloodProtection(self, file_server): file_server.ip_incoming = {} # Reset flood protection + whitelist = file_server.whitelist # Save for reset + file_server.whitelist = [] # Disable 127.0.0.1 whitelist client = ConnectionServer("127.0.0.1", 1545) # Only allow 3 connection in 1 minute @@ -98,3 +100,6 @@ class TestConnection: with pytest.raises(gevent.Timeout): with gevent.Timeout(0.1): connection = client.getConnection("127.0.0.1", 1544) + + # Reset whitelist + file_server.whitelist = whitelist diff --git a/src/Test/TestDb.py b/src/Test/TestDb.py index 55ae103d..97a165f2 100644 --- a/src/Test/TestDb.py +++ b/src/Test/TestDb.py @@ -112,6 +112,9 @@ class TestDb: assert db.execute("SELECT COUNT(*) AS num FROM test WHERE ?", {"test_id": [1,2,3], "title": "Test #2"}).fetchone()["num"] == 1 assert db.execute("SELECT COUNT(*) AS num FROM test WHERE ?", {"test_id": [1,2,3], "title": ["Test #2", "Test #3", "Test #4"]}).fetchone()["num"] == 2 + # Test named parameter escaping + assert db.execute("SELECT COUNT(*) AS num FROM test WHERE test_id = :test_id AND title LIKE :titlelike", {"test_id": 1, "titlelike": "Test%"}).fetchone()["num"] == 1 + db.close() # Cleanup diff --git a/src/Test/TestHelper.py b/src/Test/TestHelper.py index 3b4a196f..28f7f6fb 100644 --- a/src/Test/TestHelper.py +++ b/src/Test/TestHelper.py @@ -18,7 +18,7 @@ class TestHelper: with pytest.raises(socket.error): helper.packAddress("999.1.1.1", 1) - with pytest.raises(socket.error): + with pytest.raises(AssertionError): helper.unpackAddress("X") def testGetDirname(self): diff --git a/src/Test/TestTor.py b/src/Test/TestTor.py new file mode 100644 index 00000000..2cdbb9e4 --- /dev/null +++ b/src/Test/TestTor.py @@ -0,0 +1,107 @@ +import pytest +import time + +from File import FileServer +from Crypt import CryptRsa + +@pytest.mark.usefixtures("resetSettings") +@pytest.mark.usefixtures("resetTempSettings") +class TestTor: + def testDownload(self, tor_manager): + for retry in range(15): + time.sleep(1) + if tor_manager.enabled and tor_manager.conn: + break + assert tor_manager.enabled + + def testManagerConnection(self, tor_manager): + assert "250-version" in tor_manager.request("GETINFO version") + + def testAddOnion(self, tor_manager): + # Add + address = tor_manager.addOnion() + assert address + assert address in tor_manager.privatekeys + + # Delete + assert tor_manager.delOnion(address) + assert address not in tor_manager.privatekeys + + def testSignOnion(self, tor_manager): + address = tor_manager.addOnion() + + # Sign + sign = CryptRsa.sign("hello", tor_manager.getPrivatekey(address)) + assert len(sign) == 128 + + # Verify + publickey = CryptRsa.privatekeyToPublickey(tor_manager.getPrivatekey(address)) + assert len(publickey) == 140 + assert CryptRsa.verify("hello", publickey, sign) + assert not CryptRsa.verify("not hello", publickey, sign) + + # Pub to address + assert CryptRsa.publickeyToOnion(publickey) == address + + # Delete + tor_manager.delOnion(address) + + @pytest.mark.skipif(not pytest.config.getvalue("slow"), reason="--slow not requested (takes around ~ 1min)") + def testConnection(self, tor_manager, file_server, site, site_temp): + file_server.tor_manager.start_onions = True + address = file_server.tor_manager.getOnion(site.address) + assert address + print "Connecting to", address + for retry in range(5): # Wait for hidden service creation + time.sleep(10) + try: + connection = file_server.getConnection(address+".onion", 1544) + if connection: + break + except Exception, err: + continue + assert connection.handshake + assert not connection.handshake["peer_id"] # No peer_id for Tor connections + + # Return the same connection without site specified + assert file_server.getConnection(address+".onion", 1544) == connection + # No reuse for different site + assert file_server.getConnection(address+".onion", 1544, site=site) != connection + assert file_server.getConnection(address+".onion", 1544, site=site) == file_server.getConnection(address+".onion", 1544, site=site) + site_temp.address = "1OTHERSITE" + assert file_server.getConnection(address+".onion", 1544, site=site) != file_server.getConnection(address+".onion", 1544, site=site_temp) + + # Only allow to query from the locked site + file_server.sites[site.address] = site + connection_locked = file_server.getConnection(address+".onion", 1544, site=site) + assert "body" in connection_locked.request("getFile", {"site": site.address, "inner_path": "content.json", "location": 0}) + assert connection_locked.request("getFile", {"site": "1OTHERSITE", "inner_path": "content.json", "location": 0})["error"] == "Invalid site" + + def testPex(self, file_server, site, site_temp): + # Register site to currently running fileserver + site.connection_server = file_server + file_server.sites[site.address] = site + # Create a new file server to emulate new peer connecting to our peer + file_server_temp = FileServer("127.0.0.1", 1545) + site_temp.connection_server = file_server_temp + file_server_temp.sites[site_temp.address] = site_temp + # We will request peers from this + peer_source = site_temp.addPeer("127.0.0.1", 1544) + + # Get ip4 peers from source site + assert peer_source.pex(need_num=10) == 1 # Need >5 to return also return non-connected peers + assert len(site_temp.peers) == 2 # Me, and the other peer + site.addPeer("1.2.3.4", 1555) # Add peer to source site + assert peer_source.pex(need_num=10) == 1 + assert len(site_temp.peers) == 3 + assert "1.2.3.4:1555" in site_temp.peers + + # Get onion peers from source site + site.addPeer("bka4ht2bzxchy44r.onion", 1555) + assert "bka4ht2bzxchy44r.onion:1555" not in site_temp.peers + assert peer_source.pex(need_num=10) == 1 # Need >5 to return also return non-connected peers + assert "bka4ht2bzxchy44r.onion:1555" in site_temp.peers + + def testSiteOnion(self, tor_manager): + assert tor_manager.getOnion("address1") != tor_manager.getOnion("address2") + assert tor_manager.getOnion("address1") == tor_manager.getOnion("address1") diff --git a/src/Test/conftest.py b/src/Test/conftest.py index 80da90ba..9501ea32 100644 --- a/src/Test/conftest.py +++ b/src/Test/conftest.py @@ -8,6 +8,10 @@ import json import pytest import mock + +def pytest_addoption(parser): + parser.addoption("--slow", action='store_true', default=False, help="Also run slow tests") + # Config if sys.platform == "win32": PHANTOMJS_PATH = "tools/phantomjs/bin/phantomjs.exe" @@ -15,29 +19,31 @@ else: PHANTOMJS_PATH = "phantomjs" SITE_URL = "http://127.0.0.1:43110" -# Imports relative to src dir -sys.path.append( - os.path.abspath(os.path.dirname(__file__) + "/..") -) +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__) + "/../lib")) # External modules directory +sys.path.insert(0, os.path.abspath(os.path.dirname(__file__) + "/..")) # Imports relative to src dir + from Config import config config.argv = ["none"] # Dont pass any argv to config parser config.parse() config.data_dir = "src/Test/testdata" # Use test data for unittests config.debug_socket = True # Use test data for unittests +config.tor = "disabled" # Don't start Tor client logging.basicConfig(level=logging.DEBUG, stream=sys.stdout) from Plugin import PluginManager PluginManager.plugin_manager.loadPlugins() +import gevent +from gevent import monkey +monkey.patch_all(thread=False) + from Site import Site from User import UserManager from File import FileServer from Connection import ConnectionServer from Crypt import CryptConnection from Ui import UiWebsocket -import gevent -from gevent import monkey -monkey.patch_all(thread=False) +from Tor import TorManager @pytest.fixture(scope="session") @@ -128,7 +134,6 @@ def site_url(): @pytest.fixture(scope="session") def file_server(request): - CryptConnection.manager.loadCerts() # Load and create certs request.addfinalizer(CryptConnection.manager.removeCerts) # Remove cert files after end file_server = FileServer("127.0.0.1", 1544) gevent.spawn(lambda: ConnectionServer.start(file_server)) @@ -160,3 +165,14 @@ def ui_websocket(site, file_server, user): ui_websocket.testAction = testAction return ui_websocket + + +@pytest.fixture(scope="session") +def tor_manager(): + try: + tor_manager = TorManager() + tor_manager.connect() + tor_manager.startOnions() + except Exception, err: + raise pytest.skip("Test requires Tor with ControlPort: %s, %s" % (config.tor_controller, err)) + return tor_manager diff --git a/src/Tor/TorManager.py b/src/Tor/TorManager.py new file mode 100644 index 00000000..1ed3e476 --- /dev/null +++ b/src/Tor/TorManager.py @@ -0,0 +1,274 @@ +import logging +import re +import socket +import binascii +import sys +import os +import time + +import gevent +import subprocess +import atexit + +from Config import config +from Crypt import CryptRsa +from Site import SiteManager +from lib.PySocks import socks +from gevent.coros import RLock +from util import helper +from Debug import Debug + + +class TorManager: + def __init__(self, fileserver_ip=None, fileserver_port=None): + self.privatekeys = {} # Onion: Privatekey + self.site_onions = {} # Site address: Onion + self.tor_exe = "tools/tor/tor.exe" + self.tor_process = None + self.log = logging.getLogger("TorManager") + self.start_onions = None + self.conn = None + self.lock = RLock() + + if config.tor == "disable": + self.enabled = False + self.start_onions = False + self.status = "Disabled" + else: + self.enabled = True + self.status = "Waiting" + + if fileserver_port: + self.fileserver_port = fileserver_port + else: + self.fileserver_port = config.fileserver_port + + self.ip, self.port = config.tor_controller.split(":") + self.port = int(self.port) + + self.proxy_ip, self.proxy_port = config.tor_proxy.split(":") + self.proxy_port = int(self.proxy_port) + + # Test proxy port + if config.tor != "disable": + try: + if "socket_noproxy" in dir(socket): # Socket proxy-patched, use non-proxy one + self.log.debug("Socket proxy patched, using original") + conn = socket.socket_noproxy(socket.AF_INET, socket.SOCK_STREAM) + else: + conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + conn.settimeout(1) + conn.connect((self.proxy_ip, self.proxy_port)) + self.log.debug("Tor proxy port %s check ok" % config.tor_proxy) + except Exception, err: + self.log.debug("Tor proxy port %s check error: %s" % (config.tor_proxy, err)) + self.enabled = False + # Change to self-bundled Tor ports + from lib.PySocks import socks + self.port = 49051 + self.proxy_port = 49050 + socks.setdefaultproxy(socks.PROXY_TYPE_SOCKS5, "127.0.0.1", self.proxy_port) + if os.path.isfile(self.tor_exe): # Already, downloaded: sync mode + self.startTor() + else: # Not downloaded yet: Async mode + gevent.spawn(self.startTor) + + def startTor(self): + if sys.platform.startswith("win"): + try: + if not os.path.isfile(self.tor_exe): + self.downloadTor() + + self.log.info("Starting Tor client %s..." % self.tor_exe) + tor_dir = os.path.dirname(self.tor_exe) + self.tor_process = subprocess.Popen(r"%s -f torrc" % self.tor_exe, cwd=tor_dir, close_fds=True) + for wait in range(1,10): # Wait for startup + time.sleep(wait * 0.5) + self.enabled = True + if self.connect(): + break + # Terminate on exit + atexit.register(self.stopTor) + except Exception, err: + self.log.error("Error starting Tor client: %s" % Debug.formatException(err)) + self.enabled = False + return False + + def stopTor(self): + self.log.debug("Stopping...") + self.tor_process.terminate() + + def downloadTor(self): + self.log.info("Downloading Tor...") + # Check Tor webpage for link + download_page = helper.httpRequest("https://www.torproject.org/download/download.html").read() + download_url = re.search('href="(.*?tor.*?win32.*?zip)"', download_page).group(1) + if not download_url.startswith("http"): + download_url = "https://www.torproject.org/download/" + download_url + + # Download Tor client + self.log.info("Downloading %s" % download_url) + data = helper.httpRequest(download_url, as_file=True) + data_size = data.tell() + + # Handle redirect + if data_size < 1024 and "The document has moved" in data.getvalue(): + download_url = re.search('href="(.*?tor.*?win32.*?zip)"', data.getvalue()).group(1) + data = helper.httpRequest(download_url, as_file=True) + data_size = data.tell() + + if data_size > 1024: + import zipfile + zip = zipfile.ZipFile(data) + self.log.info("Unpacking Tor") + for inner_path in zip.namelist(): + if ".." in inner_path: + continue + dest_path = inner_path + dest_path = re.sub("^Data/Tor/", "tools/tor/data/", dest_path) + dest_path = re.sub("^Data/", "tools/tor/data/", dest_path) + dest_path = re.sub("^Tor/", "tools/tor/", dest_path) + dest_dir = os.path.dirname(dest_path) + if dest_dir and not os.path.isdir(dest_dir): + os.makedirs(dest_dir) + + if dest_dir != dest_path.strip("/"): + data = zip.read(inner_path) + if not os.path.isfile(dest_path): + open(dest_path, 'wb').write(data) + else: + self.log.error("Bad response from server: %s" % data.getvalue()) + return False + + def connect(self): + if not self.enabled: + return False + self.site_onions = {} + self.privatekeys = {} + + if "socket_noproxy" in dir(socket): # Socket proxy-patched, use non-proxy one + conn = socket.socket_noproxy(socket.AF_INET, socket.SOCK_STREAM) + else: + conn = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + + self.log.debug("Connecting to %s:%s" % (self.ip, self.port)) + try: + with self.lock: + conn.connect((self.ip, self.port)) + res_protocol = self.send("PROTOCOLINFO", conn) + + version = re.search('Tor="([0-9\.]+)"', res_protocol).group(1) + # Version 0.2.7.5 required because ADD_ONION support + assert int(version.replace(".", "0")) >= 20705, "Tor version >=0.2.7.5 required" + + # Auth cookie file + cookie_match = re.search('COOKIEFILE="(.*?)"', res_protocol) + if cookie_match: + cookie_file = cookie_match.group(1) + auth_hex = binascii.b2a_hex(open(cookie_file, "rb").read()) + res_auth = self.send("AUTHENTICATE %s" % auth_hex, conn) + else: + res_auth = self.send("AUTHENTICATE", conn) + + assert "250 OK" in res_auth, "Authenticate error %s" % res_auth + self.status = "Connected (%s)" % res_auth + self.conn = conn + except Exception, err: + self.conn = None + self.status = "Error (%s)" % err + self.log.error("Tor controller connect error: %s" % err) + self.enabled = False + return self.conn + + def disconnect(self): + self.conn.close() + self.conn = None + + def startOnions(self): + self.log.debug("Start onions") + self.start_onions = True + + # Get new exit node ip + def resetCircuits(self): + res = self.request("SIGNAL NEWNYM") + if "250 OK" not in res: + self.status = "Reset circuits error (%s)" % res + self.log.error("Tor reset circuits error: %s" % res) + + def addOnion(self): + res = self.request("ADD_ONION NEW:RSA1024 port=%s" % self.fileserver_port) + match = re.search("ServiceID=([A-Za-z0-9]+).*PrivateKey=RSA1024:(.*?)[\r\n]", res, re.DOTALL) + if match: + onion_address, onion_privatekey = match.groups() + self.privatekeys[onion_address] = onion_privatekey + self.status = "OK (%s onion running)" % len(self.privatekeys) + SiteManager.peer_blacklist.append((onion_address + ".onion", self.fileserver_port)) + return onion_address + else: + self.status = "AddOnion error (%s)" % res + self.log.error("Tor addOnion error: %s" % res) + return False + + def delOnion(self, address): + res = self.request("DEL_ONION %s" % address) + if "250 OK" in res: + del self.privatekeys[address] + self.status = "OK (%s onion running)" % len(self.privatekeys) + return True + else: + self.status = "DelOnion error (%s)" % res + self.log.error("Tor delOnion error: %s" % res) + self.disconnect() + return False + + def request(self, cmd): + with self.lock: + if not self.enabled: + return False + if not self.conn: + if not self.connect(): + return "" + return self.send(cmd) + + def send(self, cmd, conn=None): + if not conn: + conn = self.conn + self.log.debug("> %s" % cmd) + conn.send("%s\r\n" % cmd) + back = conn.recv(1024 * 64) + self.log.debug("< %s" % back.strip()) + return back + + def getPrivatekey(self, address): + return self.privatekeys[address] + + def getPublickey(self, address): + return CryptRsa.privatekeyToPublickey(self.privatekeys[address]) + + def getOnion(self, site_address): + with self.lock: + if not self.enabled: + return None + if self.start_onions: # Different onion for every site + onion = self.site_onions.get(site_address) + else: # Same onion for every site + onion = self.site_onions.get("global") + site_address = "global" + if not onion: + self.site_onions[site_address] = self.addOnion() + onion = self.site_onions[site_address] + self.log.debug("Created new hidden service for %s: %s" % (site_address, onion)) + return onion + + def createSocket(self, onion, port): + if not self.enabled: + return False + self.log.debug("Creating new socket to %s:%s" % (onion, port)) + if config.tor == "always": # Every socket is proxied by default + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.connect((onion, int(port))) + else: + sock = socks.socksocket() + sock.set_proxy(socks.SOCKS5, self.proxy_ip, self.proxy_port) + sock.connect((onion, int(port))) + return sock diff --git a/src/Tor/__init__.py b/src/Tor/__init__.py new file mode 100644 index 00000000..250eac2d --- /dev/null +++ b/src/Tor/__init__.py @@ -0,0 +1 @@ +from TorManager import TorManager \ No newline at end of file diff --git a/src/Ui/UiRequest.py b/src/Ui/UiRequest.py index 674de84d..9c93540e 100644 --- a/src/Ui/UiRequest.py +++ b/src/Ui/UiRequest.py @@ -140,6 +140,7 @@ class UiRequest(object): headers.append(("Connection", "Keep-Alive")) headers.append(("Keep-Alive", "max=25, timeout=30")) headers.append(("Access-Control-Allow-Origin", "*")) # Allow json access + # headers.append(("Content-Security-Policy", "default-src 'self' data: 'unsafe-inline' ws://127.0.0.1:* http://127.0.0.1:* wss://tracker.webtorrent.io; sandbox allow-same-origin allow-top-navigation allow-scripts")) # Only local connections if self.env["REQUEST_METHOD"] == "OPTIONS": # Allow json access headers.append(("Access-Control-Allow-Headers", "Origin, X-Requested-With, Content-Type, Accept, Cookie")) diff --git a/src/Ui/UiWebsocket.py b/src/Ui/UiWebsocket.py index 04f728fe..7124b575 100644 --- a/src/Ui/UiWebsocket.py +++ b/src/Ui/UiWebsocket.py @@ -33,23 +33,54 @@ class UiWebsocket(object): ws = self.ws if self.site.address == config.homepage and not self.site.page_requested: # Add open fileserver port message or closed port error to homepage at first request after start - if sys.modules["main"].file_server.port_opened is True: + self.site.page_requested = True # Dont add connection notification anymore + file_server = sys.modules["main"].file_server + if file_server.port_opened is None or file_server.tor_manager.start_onions is None: + self.site.page_requested = False # Not ready yet, check next time + elif file_server.port_opened is True: self.site.notifications.append([ "done", "Congratulation, your port %s is opened.
You are full member of ZeroNet network!" % config.fileserver_port, 10000 ]) - elif sys.modules["main"].file_server.port_opened is False: + elif config.tor == "always" and file_server.tor_manager.start_onions: + self.site.notifications.append([ + "done", + """ + Tor mode active, every connection using Onion route.
+ Successfully started Tor onion hidden services. + """, + 10000 + ]) + elif config.tor == "always" and not file_server.tor_manager.start_onions == False: self.site.notifications.append([ "error", """ - Your network connection is restricted. Please, open %s port
- on your router to become full member of ZeroNet network. + Tor mode active, every connection using Onion route.
+ Unable to start hidden services, please check your config. + """, + 0 + ]) + elif file_server.port_opened is False and file_server.tor_manager.start_onions: + self.site.notifications.append([ + "done", + """ + Successfully started Tor onion hidden services.
+ For faster connections open %s port on your router. + """ % config.fileserver_port, + 10000 + ]) + else: + self.site.notifications.append([ + "error", + """ + Your connection is restricted. Please, open %s port on your router
+ or configure Tor to become full member of ZeroNet network. """ % config.fileserver_port, 0 ]) - self.site.page_requested = True # Dont add connection notification anymore + for notification in self.site.notifications: # Send pending notification messages self.cmd("notification", notification) @@ -194,6 +225,8 @@ class UiWebsocket(object): "platform": sys.platform, "fileserver_ip": config.fileserver_ip, "fileserver_port": config.fileserver_port, + "tor_enabled": sys.modules["main"].file_server.tor_manager.enabled, + "tor_status": sys.modules["main"].file_server.tor_manager.status, "ui_ip": config.ui_ip, "ui_port": config.ui_port, "version": config.version, diff --git a/src/User/User.py b/src/User/User.py index 95c0661d..d1d81c17 100644 --- a/src/User/User.py +++ b/src/User/User.py @@ -125,7 +125,8 @@ class User(object): if domain: site_data["cert"] = domain else: - del site_data["cert"] + if "cert" in site_data: + del site_data["cert"] self.save() return site_data diff --git a/src/Worker/WorkerManager.py b/src/Worker/WorkerManager.py index 379ac069..c4da30ff 100644 --- a/src/Worker/WorkerManager.py +++ b/src/Worker/WorkerManager.py @@ -68,7 +68,7 @@ class WorkerManager: "Task taking more than 15 secs, workers: %s find more peers: %s" % (len(workers), task["inner_path"]) ) - task["site"].announce(num=1) # Find more peers + task["site"].announce(mode="more") # Find more peers if task["optional_hash_id"]: self.startFindOptional() else: diff --git a/src/lib/opensslVerify/opensslVerify.py b/src/lib/opensslVerify/opensslVerify.py index 5294816f..98378e3b 100644 --- a/src/lib/opensslVerify/opensslVerify.py +++ b/src/lib/opensslVerify/opensslVerify.py @@ -447,6 +447,7 @@ if __name__ == "__main__": sys.path.append("..") from pybitcointools import bitcoin as btctools print "OpenSSL version %s" % openssl_version + print ssl._lib priv = "5JsunC55XGVqFQj5kPGK4MWgTL26jKbnPhjnmchSNPo75XXCwtk" address = "1N2XWu5soeppX2qUjvrf81rpdbShKJrjTr" sign = btctools.ecdsa_sign("hello", priv) # HGbib2kv9gm9IJjDt1FXbXFczZi35u0rZR3iPUIt5GglDDCeIQ7v8eYXVNIaLoJRI4URGZrhwmsYQ9aVtRTnTfQ= diff --git a/src/lib/pyasn1/CHANGES b/src/lib/pyasn1/CHANGES new file mode 100644 index 00000000..561dedd8 --- /dev/null +++ b/src/lib/pyasn1/CHANGES @@ -0,0 +1,278 @@ +Revision 0.1.7 +-------------- + +- License updated to vanilla BSD 2-Clause to ease package use + (http://opensource.org/licenses/BSD-2-Clause). +- Test suite made discoverable by unittest/unittest2 discovery feature. +- Fix to decoder working on indefinite length substrate -- end-of-octets + marker is now detected by both tag and value. Otherwise zero values may + interfere with end-of-octets marker. +- Fix to decoder to fail in cases where tagFormat indicates inappropriate + format for the type (e.g. BOOLEAN is always PRIMITIVE, SET is always + CONSTRUCTED and OCTET STRING is either of the two) +- Fix to REAL type encoder to force primitive encoding form encoding. +- Fix to CHOICE decoder to handle explicitly tagged, indefinite length + mode encoding +- Fix to REAL type decoder to handle negative REAL values correctly. Test + case added. + +Revision 0.1.6 +-------------- + +- The compact (valueless) way of encoding zero INTEGERs introduced in + 0.1.5 seems to fail miserably as the world is filled with broken + BER decoders. So we had to back off the *encoder* for a while. + There's still the IntegerEncoder.supportCompactZero flag which + enables compact encoding form whenever it evaluates to True. +- Report package version on debugging code initialization. + +Revision 0.1.5 +-------------- + +- Documentation updated and split into chapters to better match + web-site contents. +- Make prettyPrint() working for non-initialized pyasn1 data objects. It + used to throw an exception. +- Fix to encoder to produce empty-payload INTEGER values for zeros +- Fix to decoder to support empty-payload INTEGER and REAL values +- Fix to unit test suites imports to be able to run each from + their current directory + +Revision 0.1.4 +-------------- + +- Built-in codec debugging facility added +- Added some more checks to ObjectIdentifier BER encoder catching + posible 2^8 overflow condition by two leading sub-OIDs +- Implementations overriding the AbstractDecoder.valueDecoder method + changed to return the rest of substrate behind the item being processed + rather than the unprocessed substrate within the item (which is usually + empty). +- Decoder's recursiveFlag feature generalized as a user callback function + which is passed an uninitialized object recovered from substrate and + its uninterpreted payload. +- Catch inappropriate substrate type passed to decoder. +- Expose tagMap/typeMap/Decoder objects at DER decoder to uniform API. +- Obsolete __init__.MajorVersionId replaced with __init__.__version__ + which is now in-sync with distutils. +- Package classifiers updated. +- The __init__.py's made non-empty (rumors are that they may be optimized + out by package managers). +- Bail out gracefully whenever Python version is older than 2.4. +- Fix to Real codec exponent encoding (should be in 2's complement form), + some more test cases added. +- Fix in Boolean truth testing built-in methods +- Fix to substrate underrun error handling at ObjectIdentifier BER decoder +- Fix to BER Boolean decoder that allows other pre-computed + values besides 0 and 1 +- Fix to leading 0x80 octet handling in DER/CER/DER ObjectIdentifier decoder. + See http://www.cosic.esat.kuleuven.be/publications/article-1432.pdf + +Revision 0.1.3 +-------------- + +- Include class name into asn1 value constraint violation exception. +- Fix to OctetString.prettyOut() method that looses leading zero when + building hex string. + +Revision 0.1.2 +-------------- + +- Fix to __long__() to actually return longs on py2k +- Fix to OctetString.__str__() workings of a non-initialized object. +- Fix to quote initializer of OctetString.__repr__() +- Minor fix towards ObjectIdentifier.prettyIn() reliability +- ObjectIdentifier.__str__() is aliased to prettyPrint() +- Exlicit repr() calls replaced with '%r' + +Revision 0.1.1 +-------------- + +- Hex/bin string initializer to OctetString object reworked + (in a backward-incompatible manner) +- Fixed float() infinity compatibility issue (affects 2.5 and earlier) +- Fixed a bug/typo at Boolean CER encoder. +- Major overhawl for Python 2.4 -- 3.2 compatibility: + + get rid of old-style types + + drop string module usage + + switch to rich comparation + + drop explicit long integer type use + + map()/filter() replaced with list comprehension + + apply() replaced with */**args + + switched to use 'key' sort() callback function + + support both __nonzero__() and __bool__() methods + + modified not to use py3k-incompatible exception syntax + + getslice() operator fully replaced with getitem() + + dictionary operations made 2K/3K compatible + + base type for encoding substrate and OctetString-based types + is now 'bytes' when running py3k and 'str' otherwise + + OctetString and derivatives now unicode compliant. + + OctetString now supports two python-neutral getters: asOcts() & asInts() + + print OctetString content in hex whenever it is not printable otherwise + + in test suite, implicit relative import replaced with the absolute one + + in test suite, string constants replaced with numerics + +Revision 0.0.13 +--------------- + +- Fix to base10 normalization function that loops on univ.Real(0) + +Revision 0.0.13b +---------------- + +- ASN.1 Real type is now supported properly. +- Objects of Constructed types now support __setitem__() +- Set/Sequence objects can now be addressed by their field names (string index) + and position (integer index). +- Typo fix to ber.SetDecoder code that prevented guided decoding operation. +- Fix to explicitly tagged items decoding support. +- Fix to OctetString.prettyPrint() to better handle non-printable content. +- Fix to repr() workings of Choice objects. + +Revision 0.0.13a +---------------- + +- Major codec re-design. +- Documentation significantly improved. +- ASN.1 Any type is now supported. +- All example ASN.1 modules moved to separate pyasn1-modules package. +- Fix to initial sub-OID overflow condition detection an encoder. +- BitString initialization value verification improved. +- The Set/Sequence.getNameByPosition() method implemented. +- Fix to proper behaviour of PermittedAlphabetConstraint object. +- Fix to improper Boolean substrate handling at CER/DER decoders. +- Changes towards performance improvement: + + all dict.has_key() & dict.get() invocations replaced with modern syntax + (this breaks compatibility with Python 2.1 and older). + + tag and tagset caches introduced to decoder + + decoder code improved to prevent unnecessary pyasn1 objects creation + + allow disabling components verification when setting components to + structured types, this is used by decoder whilst running in guided mode. + + BER decoder for integer values now looks up a small set of pre-computed + substrate values to save on decoding. + + a few pre-computed values configured to ObjectIdentifier BER encoder. + + ChoiceDecoder split-off SequenceOf one to save on unnecessary checks. + + replace slow hasattr()/getattr() calls with isinstance() introspection. + + track the number of initialized components of Constructed types to save + on default/optional components initialization. + + added a shortcut ObjectIdentifier.asTuple() to be used instead of + __getitem__() in hotspots. + + use Tag.asTuple() and pure integers at tag encoder. + + introduce and use in decoder the baseTagSet attribute of the built-in + ASN.1 types. + +Revision 0.0.12a +---------------- + +- The individual tag/length/value processing methods of + encoder.AbstractItemEncoder renamed (leading underscore stripped) + to promote overloading in cases where partial substrate processing + is required. +- The ocsp.py, ldap.py example scripts added. +- Fix to univ.ObjectIdentifier input value handler to disallow negative + sub-IDs. + +Revision 0.0.11a +---------------- + +- Decoder can now treat values of unknown types as opaque OctetString. +- Fix to Set/SetOf type decoder to handle uninitialized scalar SetOf + components correctly. + +Revision 0.0.10a +---------------- + +- API versioning mechanics retired (pyasn1.v1 -> pyasn1) what makes + it possible to zip-import pyasn1 sources (used by egg and py2exe). + +Revision 0.0.9a +--------------- + +- Allow any non-zero values in Boolean type BER decoder, as it's in + accordnance with the standard. + +Revision 0.0.8a +--------------- + +- Integer.__index__() now supported (for Python 2.5+). +- Fix to empty value encoding in BitString encoder, test case added. +- Fix to SequenceOf decoder that prevents it skipping possible Choice + typed inner component. +- Choice.getName() method added for getting currently set component + name. +- OctetsString.prettyPrint() does a single str() against its value + eliminating an extra quotes. + +Revision 0.0.7a +--------------- + +- Large tags (>31) now supported by codecs. +- Fix to encoder to properly handle explicitly tagged untagged items. +- All possible value lengths (up to 256^126) now supported by encoders. +- Fix to Tag class constructor to prevent negative IDs. + +Revision 0.0.6a +--------------- + +- Make use of setuptools. +- Constraints derivation verification (isSuperTypeOf()/isSubTypeOf()) fixed. +- Fix to constraints comparation logic -- can't cmp() hash values as it + may cause false positives due to hash conflicts. + +Revision 0.0.5a +--------------- + +- Integer BER codec reworked fixing negative values encoding bug. +- clone() and subtype() methods of Constructed ASN.1 classes now + accept optional cloneValueFlag flag which controls original value + inheritance. The default is *not* to inherit original value for + performance reasons (this may affect backward compatibility). + Performance penalty may be huge on deeply nested Constructed objects + re-creation. +- Base ASN.1 types (pyasn1.type.univ.*) do not have default values + anymore. They remain uninitialized acting as ASN.1 types. In + this model, initialized ASN.1 types represent either types with + default value installed or a type instance. +- Decoders' prototypes are now class instances rather than classes. + This is to simplify initial value installation to decoder's + prototype value. +- Bugfix to BitString BER decoder (trailing bits not regarded). +- Bugfix to Constraints use as mapping keys. +- Bugfix to Integer & BitString clone() methods +- Bugix to the way to distinguish Set from SetOf at CER/DER SetOfEncoder +- Adjustments to make it running on Python 1.5. +- In tests, substrate constants converted from hex escaped literals into + octals to overcome indefinite hex width issue occuring in young Python. +- Minor performance optimization of TagSet.isSuperTagSetOf() method +- examples/sshkey.py added + +Revision 0.0.4a +--------------- + +* Asn1ItemBase.prettyPrinter() -> *.prettyPrint() + +Revision 0.0.3a +--------------- + +* Simple ASN1 objects now hash to their Python value and don't + depend upon tag/constraints/etc. +* prettyIn & prettyOut methods of SimplleAsn1Object become public +* many syntax fixes + +Revision 0.0.2a +--------------- + +* ConstraintsIntersection.isSuperTypeOf() and + ConstraintsIntersection.hasConstraint() implemented +* Bugfix to NamedValues initialization code +* +/- operators added to NamedValues objects +* Integer.__abs__() & Integer.subtype() added +* ObjectIdentifier.prettyOut() fixes +* Allow subclass components at SequenceAndSetBase +* AbstractConstraint.__cmp__() dropped +* error.Asn1Error replaced with error.PyAsn1Error + +Revision 0.0.1a +--------------- + +* Initial public alpha release diff --git a/src/lib/pyasn1/LICENSE b/src/lib/pyasn1/LICENSE new file mode 100644 index 00000000..fac589b8 --- /dev/null +++ b/src/lib/pyasn1/LICENSE @@ -0,0 +1,24 @@ +Copyright (c) 2005-2013, Ilya Etingof +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + * Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE. diff --git a/src/lib/pyasn1/PKG-INFO b/src/lib/pyasn1/PKG-INFO new file mode 100644 index 00000000..5de78ece --- /dev/null +++ b/src/lib/pyasn1/PKG-INFO @@ -0,0 +1,26 @@ +Metadata-Version: 1.0 +Name: pyasn1 +Version: 0.1.7 +Summary: ASN.1 types and codecs +Home-page: http://sourceforge.net/projects/pyasn1/ +Author: Ilya Etingof +Author-email: ilya@glas.net +License: BSD +Description: A pure-Python implementation of ASN.1 types and DER/BER/CER codecs (X.208). +Platform: any +Classifier: Development Status :: 5 - Production/Stable +Classifier: Environment :: Console +Classifier: Intended Audience :: Developers +Classifier: Intended Audience :: Education +Classifier: Intended Audience :: Information Technology +Classifier: Intended Audience :: Science/Research +Classifier: Intended Audience :: System Administrators +Classifier: Intended Audience :: Telecommunications Industry +Classifier: License :: OSI Approved :: BSD License +Classifier: Natural Language :: English +Classifier: Operating System :: OS Independent +Classifier: Programming Language :: Python :: 2 +Classifier: Programming Language :: Python :: 3 +Classifier: Topic :: Communications +Classifier: Topic :: Security :: Cryptography +Classifier: Topic :: Software Development :: Libraries :: Python Modules diff --git a/src/lib/pyasn1/README b/src/lib/pyasn1/README new file mode 100644 index 00000000..ffa3b57e --- /dev/null +++ b/src/lib/pyasn1/README @@ -0,0 +1,68 @@ + +ASN.1 library for Python +------------------------ + +This is an implementation of ASN.1 types and codecs in Python programming +language. It has been first written to support particular protocol (SNMP) +but then generalized to be suitable for a wide range of protocols +based on ASN.1 specification. + +FEATURES +-------- + +* Generic implementation of ASN.1 types (X.208) +* Fully standard compliant BER/CER/DER codecs +* 100% Python, works with Python 2.4 up to Python 3.3 (beta 1) +* MT-safe + +MISFEATURES +----------- + +* No ASN.1 compiler (by-hand ASN.1 spec compilation into Python code required) +* Codecs are not restartable + +INSTALLATION +------------ + +The pyasn1 package uses setuptools/distutils for installation. Thus do +either: + +$ easy_install pyasn1 + +or + +$ tar zxf pyasn1-0.1.3.tar.gz +$ cd pyasn1-0.1.3 +$ python setup.py install +$ cd test +$ python suite.py # run unit tests + +OPERATION +--------- + +Perhaps a typical use would involve [by-hand] compilation of your ASN.1 +specification into pyasn1-backed Python code at your application. + +For more information on pyasn1 APIs, please, refer to the +doc/pyasn1-tutorial.html file in the distribution. + +Also refer to example modules. Take a look at pyasn1-modules package -- maybe +it already holds something useful to you. + +AVAILABILITY +------------ + +The pyasn1 package is distributed under terms and conditions of BSD-style +license. See LICENSE file in the distribution. Source code is freely +available from: + +http://pyasn1.sf.net + + +FEEDBACK +-------- + +Please, send your comments and fixes to mailing lists at project web site. + +=-=-= +mailto: ilya@glas.net diff --git a/src/lib/pyasn1/THANKS b/src/lib/pyasn1/THANKS new file mode 100644 index 00000000..4de1713c --- /dev/null +++ b/src/lib/pyasn1/THANKS @@ -0,0 +1,4 @@ +Denis S. Otkidach +Gregory Golberg +Bud P. Bruegger +Jacek Konieczny diff --git a/src/lib/pyasn1/TODO b/src/lib/pyasn1/TODO new file mode 100644 index 00000000..0ee211c2 --- /dev/null +++ b/src/lib/pyasn1/TODO @@ -0,0 +1,36 @@ +* Specialize ASN.1 character and useful types +* Come up with simpler API for deeply nested constructed objects + addressing + +ber.decoder: +* suspend codec on underrun error ? +* class-static components map (in simple type classes) +* present subtypes ? +* component presence check wont work at innertypeconst +* add the rest of ASN1 types/codecs +* type vs value, defaultValue + +ber.encoder: +* Asn1Item.clone() / shallowcopy issue +* large length encoder? +* codec restart +* preserve compatible API whenever stateful codec gets implemented +* restartable vs incremental +* plan: make a stateless univeral decoder, then convert it to restartable + then to incremental + +type.useful: +* may need to implement prettyIn/Out + +type.char: +* may need to implement constraints + +type.univ: +* simpler API to constructed objects: value init, recursive + +type.namedtypes +* type vs tagset name convention + +general: + +* how untagged TagSet should be initialized? diff --git a/src/lib/pyasn1/__init__.py b/src/lib/pyasn1/__init__.py new file mode 100644 index 00000000..88aff79c --- /dev/null +++ b/src/lib/pyasn1/__init__.py @@ -0,0 +1,8 @@ +import sys + +# http://www.python.org/dev/peps/pep-0396/ +__version__ = '0.1.7' + +if sys.version_info[:2] < (2, 4): + raise RuntimeError('PyASN1 requires Python 2.4 or later') + diff --git a/src/lib/pyasn1/codec/__init__.py b/src/lib/pyasn1/codec/__init__.py new file mode 100644 index 00000000..8c3066b2 --- /dev/null +++ b/src/lib/pyasn1/codec/__init__.py @@ -0,0 +1 @@ +# This file is necessary to make this directory a package. diff --git a/src/lib/pyasn1/codec/ber/__init__.py b/src/lib/pyasn1/codec/ber/__init__.py new file mode 100644 index 00000000..8c3066b2 --- /dev/null +++ b/src/lib/pyasn1/codec/ber/__init__.py @@ -0,0 +1 @@ +# This file is necessary to make this directory a package. diff --git a/src/lib/pyasn1/codec/ber/decoder.py b/src/lib/pyasn1/codec/ber/decoder.py new file mode 100644 index 00000000..be0cf490 --- /dev/null +++ b/src/lib/pyasn1/codec/ber/decoder.py @@ -0,0 +1,808 @@ +# BER decoder +from pyasn1.type import tag, base, univ, char, useful, tagmap +from pyasn1.codec.ber import eoo +from pyasn1.compat.octets import oct2int, octs2ints, isOctetsType +from pyasn1 import debug, error + +class AbstractDecoder: + protoComponent = None + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + raise error.PyAsn1Error('Decoder not implemented for %s' % (tagSet,)) + + def indefLenValueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + raise error.PyAsn1Error('Indefinite length mode decoder not implemented for %s' % (tagSet,)) + +class AbstractSimpleDecoder(AbstractDecoder): + tagFormats = (tag.tagFormatSimple,) + def _createComponent(self, asn1Spec, tagSet, value=None): + if tagSet[0][1] not in self.tagFormats: + raise error.PyAsn1Error('Invalid tag format %r for %r' % (tagSet[0], self.protoComponent,)) + if asn1Spec is None: + return self.protoComponent.clone(value, tagSet) + elif value is None: + return asn1Spec + else: + return asn1Spec.clone(value) + +class AbstractConstructedDecoder(AbstractDecoder): + tagFormats = (tag.tagFormatConstructed,) + def _createComponent(self, asn1Spec, tagSet, value=None): + if tagSet[0][1] not in self.tagFormats: + raise error.PyAsn1Error('Invalid tag format %r for %r' % (tagSet[0], self.protoComponent,)) + if asn1Spec is None: + return self.protoComponent.clone(tagSet) + else: + return asn1Spec.clone() + +class EndOfOctetsDecoder(AbstractSimpleDecoder): + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + return eoo.endOfOctets, substrate[length:] + +class ExplicitTagDecoder(AbstractSimpleDecoder): + protoComponent = univ.Any('') + tagFormats = (tag.tagFormatConstructed,) + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + if substrateFun: + return substrateFun( + self._createComponent(asn1Spec, tagSet, ''), + substrate, length + ) + head, tail = substrate[:length], substrate[length:] + value, _ = decodeFun(head, asn1Spec, tagSet, length) + return value, tail + + def indefLenValueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + if substrateFun: + return substrateFun( + self._createComponent(asn1Spec, tagSet, ''), + substrate, length + ) + value, substrate = decodeFun(substrate, asn1Spec, tagSet, length) + terminator, substrate = decodeFun(substrate) + if eoo.endOfOctets.isSameTypeWith(terminator) and \ + terminator == eoo.endOfOctets: + return value, substrate + else: + raise error.PyAsn1Error('Missing end-of-octets terminator') + +explicitTagDecoder = ExplicitTagDecoder() + +class IntegerDecoder(AbstractSimpleDecoder): + protoComponent = univ.Integer(0) + precomputedValues = { + '\x00': 0, + '\x01': 1, + '\x02': 2, + '\x03': 3, + '\x04': 4, + '\x05': 5, + '\x06': 6, + '\x07': 7, + '\x08': 8, + '\x09': 9, + '\xff': -1, + '\xfe': -2, + '\xfd': -3, + '\xfc': -4, + '\xfb': -5 + } + + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, length, + state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + if not head: + return self._createComponent(asn1Spec, tagSet, 0), tail + if head in self.precomputedValues: + value = self.precomputedValues[head] + else: + firstOctet = oct2int(head[0]) + if firstOctet & 0x80: + value = -1 + else: + value = 0 + for octet in head: + value = value << 8 | oct2int(octet) + return self._createComponent(asn1Spec, tagSet, value), tail + +class BooleanDecoder(IntegerDecoder): + protoComponent = univ.Boolean(0) + def _createComponent(self, asn1Spec, tagSet, value=None): + return IntegerDecoder._createComponent(self, asn1Spec, tagSet, value and 1 or 0) + +class BitStringDecoder(AbstractSimpleDecoder): + protoComponent = univ.BitString(()) + tagFormats = (tag.tagFormatSimple, tag.tagFormatConstructed) + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, length, + state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + if tagSet[0][1] == tag.tagFormatSimple: # XXX what tag to check? + if not head: + raise error.PyAsn1Error('Empty substrate') + trailingBits = oct2int(head[0]) + if trailingBits > 7: + raise error.PyAsn1Error( + 'Trailing bits overflow %s' % trailingBits + ) + head = head[1:] + lsb = p = 0; l = len(head)-1; b = () + while p <= l: + if p == l: + lsb = trailingBits + j = 7 + o = oct2int(head[p]) + while j >= lsb: + b = b + ((o>>j)&0x01,) + j = j - 1 + p = p + 1 + return self._createComponent(asn1Spec, tagSet, b), tail + r = self._createComponent(asn1Spec, tagSet, ()) + if substrateFun: + return substrateFun(r, substrate, length) + while head: + component, head = decodeFun(head) + r = r + component + return r, tail + + def indefLenValueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + r = self._createComponent(asn1Spec, tagSet, '') + if substrateFun: + return substrateFun(r, substrate, length) + while substrate: + component, substrate = decodeFun(substrate) + if eoo.endOfOctets.isSameTypeWith(component) and \ + component == eoo.endOfOctets: + break + r = r + component + else: + raise error.SubstrateUnderrunError( + 'No EOO seen before substrate ends' + ) + return r, substrate + +class OctetStringDecoder(AbstractSimpleDecoder): + protoComponent = univ.OctetString('') + tagFormats = (tag.tagFormatSimple, tag.tagFormatConstructed) + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, length, + state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + if tagSet[0][1] == tag.tagFormatSimple: # XXX what tag to check? + return self._createComponent(asn1Spec, tagSet, head), tail + r = self._createComponent(asn1Spec, tagSet, '') + if substrateFun: + return substrateFun(r, substrate, length) + while head: + component, head = decodeFun(head) + r = r + component + return r, tail + + def indefLenValueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + r = self._createComponent(asn1Spec, tagSet, '') + if substrateFun: + return substrateFun(r, substrate, length) + while substrate: + component, substrate = decodeFun(substrate) + if eoo.endOfOctets.isSameTypeWith(component) and \ + component == eoo.endOfOctets: + break + r = r + component + else: + raise error.SubstrateUnderrunError( + 'No EOO seen before substrate ends' + ) + return r, substrate + +class NullDecoder(AbstractSimpleDecoder): + protoComponent = univ.Null('') + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + r = self._createComponent(asn1Spec, tagSet) + if head: + raise error.PyAsn1Error('Unexpected %d-octet substrate for Null' % length) + return r, tail + +class ObjectIdentifierDecoder(AbstractSimpleDecoder): + protoComponent = univ.ObjectIdentifier(()) + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, length, + state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + if not head: + raise error.PyAsn1Error('Empty substrate') + + # Get the first subid + subId = oct2int(head[0]) + oid = divmod(subId, 40) + + index = 1 + substrateLen = len(head) + while index < substrateLen: + subId = oct2int(head[index]) + index = index + 1 + if subId == 128: + # ASN.1 spec forbids leading zeros (0x80) in sub-ID OID + # encoding, tolerating it opens a vulnerability. + # See http://www.cosic.esat.kuleuven.be/publications/article-1432.pdf page 7 + raise error.PyAsn1Error('Invalid leading 0x80 in sub-OID') + elif subId > 128: + # Construct subid from a number of octets + nextSubId = subId + subId = 0 + while nextSubId >= 128: + subId = (subId << 7) + (nextSubId & 0x7F) + if index >= substrateLen: + raise error.SubstrateUnderrunError( + 'Short substrate for sub-OID past %s' % (oid,) + ) + nextSubId = oct2int(head[index]) + index = index + 1 + subId = (subId << 7) + nextSubId + oid = oid + (subId,) + return self._createComponent(asn1Spec, tagSet, oid), tail + +class RealDecoder(AbstractSimpleDecoder): + protoComponent = univ.Real() + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + if not head: + return self._createComponent(asn1Spec, tagSet, 0.0), tail + fo = oct2int(head[0]); head = head[1:] + if fo & 0x80: # binary enoding + n = (fo & 0x03) + 1 + if n == 4: + n = oct2int(head[0]) + eo, head = head[:n], head[n:] + if not eo or not head: + raise error.PyAsn1Error('Real exponent screwed') + e = oct2int(eo[0]) & 0x80 and -1 or 0 + while eo: # exponent + e <<= 8 + e |= oct2int(eo[0]) + eo = eo[1:] + p = 0 + while head: # value + p <<= 8 + p |= oct2int(head[0]) + head = head[1:] + if fo & 0x40: # sign bit + p = -p + value = (p, 2, e) + elif fo & 0x40: # infinite value + value = fo & 0x01 and '-inf' or 'inf' + elif fo & 0xc0 == 0: # character encoding + try: + if fo & 0x3 == 0x1: # NR1 + value = (int(head), 10, 0) + elif fo & 0x3 == 0x2: # NR2 + value = float(head) + elif fo & 0x3 == 0x3: # NR3 + value = float(head) + else: + raise error.SubstrateUnderrunError( + 'Unknown NR (tag %s)' % fo + ) + except ValueError: + raise error.SubstrateUnderrunError( + 'Bad character Real syntax' + ) + else: + raise error.SubstrateUnderrunError( + 'Unknown encoding (tag %s)' % fo + ) + return self._createComponent(asn1Spec, tagSet, value), tail + +class SequenceDecoder(AbstractConstructedDecoder): + protoComponent = univ.Sequence() + def _getComponentTagMap(self, r, idx): + try: + return r.getComponentTagMapNearPosition(idx) + except error.PyAsn1Error: + return + + def _getComponentPositionByType(self, r, t, idx): + return r.getComponentPositionNearType(t, idx) + + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + r = self._createComponent(asn1Spec, tagSet) + idx = 0 + if substrateFun: + return substrateFun(r, substrate, length) + while head: + asn1Spec = self._getComponentTagMap(r, idx) + component, head = decodeFun(head, asn1Spec) + idx = self._getComponentPositionByType( + r, component.getEffectiveTagSet(), idx + ) + r.setComponentByPosition(idx, component, asn1Spec is None) + idx = idx + 1 + r.setDefaultComponents() + r.verifySizeSpec() + return r, tail + + def indefLenValueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + r = self._createComponent(asn1Spec, tagSet) + if substrateFun: + return substrateFun(r, substrate, length) + idx = 0 + while substrate: + asn1Spec = self._getComponentTagMap(r, idx) + component, substrate = decodeFun(substrate, asn1Spec) + if eoo.endOfOctets.isSameTypeWith(component) and \ + component == eoo.endOfOctets: + break + idx = self._getComponentPositionByType( + r, component.getEffectiveTagSet(), idx + ) + r.setComponentByPosition(idx, component, asn1Spec is None) + idx = idx + 1 + else: + raise error.SubstrateUnderrunError( + 'No EOO seen before substrate ends' + ) + r.setDefaultComponents() + r.verifySizeSpec() + return r, substrate + +class SequenceOfDecoder(AbstractConstructedDecoder): + protoComponent = univ.SequenceOf() + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + r = self._createComponent(asn1Spec, tagSet) + if substrateFun: + return substrateFun(r, substrate, length) + asn1Spec = r.getComponentType() + idx = 0 + while head: + component, head = decodeFun(head, asn1Spec) + r.setComponentByPosition(idx, component, asn1Spec is None) + idx = idx + 1 + r.verifySizeSpec() + return r, tail + + def indefLenValueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + r = self._createComponent(asn1Spec, tagSet) + if substrateFun: + return substrateFun(r, substrate, length) + asn1Spec = r.getComponentType() + idx = 0 + while substrate: + component, substrate = decodeFun(substrate, asn1Spec) + if eoo.endOfOctets.isSameTypeWith(component) and \ + component == eoo.endOfOctets: + break + r.setComponentByPosition(idx, component, asn1Spec is None) + idx = idx + 1 + else: + raise error.SubstrateUnderrunError( + 'No EOO seen before substrate ends' + ) + r.verifySizeSpec() + return r, substrate + +class SetDecoder(SequenceDecoder): + protoComponent = univ.Set() + def _getComponentTagMap(self, r, idx): + return r.getComponentTagMap() + + def _getComponentPositionByType(self, r, t, idx): + nextIdx = r.getComponentPositionByType(t) + if nextIdx is None: + return idx + else: + return nextIdx + +class SetOfDecoder(SequenceOfDecoder): + protoComponent = univ.SetOf() + +class ChoiceDecoder(AbstractConstructedDecoder): + protoComponent = univ.Choice() + tagFormats = (tag.tagFormatSimple, tag.tagFormatConstructed) + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + r = self._createComponent(asn1Spec, tagSet) + if substrateFun: + return substrateFun(r, substrate, length) + if r.getTagSet() == tagSet: # explicitly tagged Choice + component, head = decodeFun( + head, r.getComponentTagMap() + ) + else: + component, head = decodeFun( + head, r.getComponentTagMap(), tagSet, length, state + ) + if isinstance(component, univ.Choice): + effectiveTagSet = component.getEffectiveTagSet() + else: + effectiveTagSet = component.getTagSet() + r.setComponentByType(effectiveTagSet, component, 0, asn1Spec is None) + return r, tail + + def indefLenValueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + r = self._createComponent(asn1Spec, tagSet) + if substrateFun: + return substrateFun(r, substrate, length) + if r.getTagSet() == tagSet: # explicitly tagged Choice + component, substrate = decodeFun(substrate, r.getComponentTagMap()) + eooMarker, substrate = decodeFun(substrate) # eat up EOO marker + if not eoo.endOfOctets.isSameTypeWith(eooMarker) or \ + eooMarker != eoo.endOfOctets: + raise error.PyAsn1Error('No EOO seen before substrate ends') + else: + component, substrate= decodeFun( + substrate, r.getComponentTagMap(), tagSet, length, state + ) + if isinstance(component, univ.Choice): + effectiveTagSet = component.getEffectiveTagSet() + else: + effectiveTagSet = component.getTagSet() + r.setComponentByType(effectiveTagSet, component, 0, asn1Spec is None) + return r, substrate + +class AnyDecoder(AbstractSimpleDecoder): + protoComponent = univ.Any() + tagFormats = (tag.tagFormatSimple, tag.tagFormatConstructed) + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + if asn1Spec is None or \ + asn1Spec is not None and tagSet != asn1Spec.getTagSet(): + # untagged Any container, recover inner header substrate + length = length + len(fullSubstrate) - len(substrate) + substrate = fullSubstrate + if substrateFun: + return substrateFun(self._createComponent(asn1Spec, tagSet), + substrate, length) + head, tail = substrate[:length], substrate[length:] + return self._createComponent(asn1Spec, tagSet, value=head), tail + + def indefLenValueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, + length, state, decodeFun, substrateFun): + if asn1Spec is not None and tagSet == asn1Spec.getTagSet(): + # tagged Any type -- consume header substrate + header = '' + else: + # untagged Any, recover header substrate + header = fullSubstrate[:-len(substrate)] + + r = self._createComponent(asn1Spec, tagSet, header) + + # Any components do not inherit initial tag + asn1Spec = self.protoComponent + + if substrateFun: + return substrateFun(r, substrate, length) + while substrate: + component, substrate = decodeFun(substrate, asn1Spec) + if eoo.endOfOctets.isSameTypeWith(component) and \ + component == eoo.endOfOctets: + break + r = r + component + else: + raise error.SubstrateUnderrunError( + 'No EOO seen before substrate ends' + ) + return r, substrate + +# character string types +class UTF8StringDecoder(OctetStringDecoder): + protoComponent = char.UTF8String() +class NumericStringDecoder(OctetStringDecoder): + protoComponent = char.NumericString() +class PrintableStringDecoder(OctetStringDecoder): + protoComponent = char.PrintableString() +class TeletexStringDecoder(OctetStringDecoder): + protoComponent = char.TeletexString() +class VideotexStringDecoder(OctetStringDecoder): + protoComponent = char.VideotexString() +class IA5StringDecoder(OctetStringDecoder): + protoComponent = char.IA5String() +class GraphicStringDecoder(OctetStringDecoder): + protoComponent = char.GraphicString() +class VisibleStringDecoder(OctetStringDecoder): + protoComponent = char.VisibleString() +class GeneralStringDecoder(OctetStringDecoder): + protoComponent = char.GeneralString() +class UniversalStringDecoder(OctetStringDecoder): + protoComponent = char.UniversalString() +class BMPStringDecoder(OctetStringDecoder): + protoComponent = char.BMPString() + +# "useful" types +class GeneralizedTimeDecoder(OctetStringDecoder): + protoComponent = useful.GeneralizedTime() +class UTCTimeDecoder(OctetStringDecoder): + protoComponent = useful.UTCTime() + +tagMap = { + eoo.endOfOctets.tagSet: EndOfOctetsDecoder(), + univ.Integer.tagSet: IntegerDecoder(), + univ.Boolean.tagSet: BooleanDecoder(), + univ.BitString.tagSet: BitStringDecoder(), + univ.OctetString.tagSet: OctetStringDecoder(), + univ.Null.tagSet: NullDecoder(), + univ.ObjectIdentifier.tagSet: ObjectIdentifierDecoder(), + univ.Enumerated.tagSet: IntegerDecoder(), + univ.Real.tagSet: RealDecoder(), + univ.Sequence.tagSet: SequenceDecoder(), # conflicts with SequenceOf + univ.Set.tagSet: SetDecoder(), # conflicts with SetOf + univ.Choice.tagSet: ChoiceDecoder(), # conflicts with Any + # character string types + char.UTF8String.tagSet: UTF8StringDecoder(), + char.NumericString.tagSet: NumericStringDecoder(), + char.PrintableString.tagSet: PrintableStringDecoder(), + char.TeletexString.tagSet: TeletexStringDecoder(), + char.VideotexString.tagSet: VideotexStringDecoder(), + char.IA5String.tagSet: IA5StringDecoder(), + char.GraphicString.tagSet: GraphicStringDecoder(), + char.VisibleString.tagSet: VisibleStringDecoder(), + char.GeneralString.tagSet: GeneralStringDecoder(), + char.UniversalString.tagSet: UniversalStringDecoder(), + char.BMPString.tagSet: BMPStringDecoder(), + # useful types + useful.GeneralizedTime.tagSet: GeneralizedTimeDecoder(), + useful.UTCTime.tagSet: UTCTimeDecoder() + } + +# Type-to-codec map for ambiguous ASN.1 types +typeMap = { + univ.Set.typeId: SetDecoder(), + univ.SetOf.typeId: SetOfDecoder(), + univ.Sequence.typeId: SequenceDecoder(), + univ.SequenceOf.typeId: SequenceOfDecoder(), + univ.Choice.typeId: ChoiceDecoder(), + univ.Any.typeId: AnyDecoder() + } + +( stDecodeTag, stDecodeLength, stGetValueDecoder, stGetValueDecoderByAsn1Spec, + stGetValueDecoderByTag, stTryAsExplicitTag, stDecodeValue, + stDumpRawValue, stErrorCondition, stStop ) = [x for x in range(10)] + +class Decoder: + defaultErrorState = stErrorCondition +# defaultErrorState = stDumpRawValue + defaultRawDecoder = AnyDecoder() + def __init__(self, tagMap, typeMap={}): + self.__tagMap = tagMap + self.__typeMap = typeMap + self.__endOfOctetsTagSet = eoo.endOfOctets.getTagSet() + # Tag & TagSet objects caches + self.__tagCache = {} + self.__tagSetCache = {} + + def __call__(self, substrate, asn1Spec=None, tagSet=None, + length=None, state=stDecodeTag, recursiveFlag=1, + substrateFun=None): + if debug.logger & debug.flagDecoder: + debug.logger('decoder called at scope %s with state %d, working with up to %d octets of substrate: %s' % (debug.scope, state, len(substrate), debug.hexdump(substrate))) + fullSubstrate = substrate + while state != stStop: + if state == stDecodeTag: + # Decode tag + if not substrate: + raise error.SubstrateUnderrunError( + 'Short octet stream on tag decoding' + ) + if not isOctetsType(substrate) and \ + not isinstance(substrate, univ.OctetString): + raise error.PyAsn1Error('Bad octet stream type') + + firstOctet = substrate[0] + substrate = substrate[1:] + if firstOctet in self.__tagCache: + lastTag = self.__tagCache[firstOctet] + else: + t = oct2int(firstOctet) + tagClass = t&0xC0 + tagFormat = t&0x20 + tagId = t&0x1F + if tagId == 0x1F: + tagId = 0 + while 1: + if not substrate: + raise error.SubstrateUnderrunError( + 'Short octet stream on long tag decoding' + ) + t = oct2int(substrate[0]) + tagId = tagId << 7 | (t&0x7F) + substrate = substrate[1:] + if not t&0x80: + break + lastTag = tag.Tag( + tagClass=tagClass, tagFormat=tagFormat, tagId=tagId + ) + if tagId < 31: + # cache short tags + self.__tagCache[firstOctet] = lastTag + if tagSet is None: + if firstOctet in self.__tagSetCache: + tagSet = self.__tagSetCache[firstOctet] + else: + # base tag not recovered + tagSet = tag.TagSet((), lastTag) + if firstOctet in self.__tagCache: + self.__tagSetCache[firstOctet] = tagSet + else: + tagSet = lastTag + tagSet + state = stDecodeLength + debug.logger and debug.logger & debug.flagDecoder and debug.logger('tag decoded into %r, decoding length' % tagSet) + if state == stDecodeLength: + # Decode length + if not substrate: + raise error.SubstrateUnderrunError( + 'Short octet stream on length decoding' + ) + firstOctet = oct2int(substrate[0]) + if firstOctet == 128: + size = 1 + length = -1 + elif firstOctet < 128: + length, size = firstOctet, 1 + else: + size = firstOctet & 0x7F + # encoded in size bytes + length = 0 + lengthString = substrate[1:size+1] + # missing check on maximum size, which shouldn't be a + # problem, we can handle more than is possible + if len(lengthString) != size: + raise error.SubstrateUnderrunError( + '%s<%s at %s' % + (size, len(lengthString), tagSet) + ) + for char in lengthString: + length = (length << 8) | oct2int(char) + size = size + 1 + substrate = substrate[size:] + if length != -1 and len(substrate) < length: + raise error.SubstrateUnderrunError( + '%d-octet short' % (length - len(substrate)) + ) + state = stGetValueDecoder + debug.logger and debug.logger & debug.flagDecoder and debug.logger('value length decoded into %d, payload substrate is: %s' % (length, debug.hexdump(length == -1 and substrate or substrate[:length]))) + if state == stGetValueDecoder: + if asn1Spec is None: + state = stGetValueDecoderByTag + else: + state = stGetValueDecoderByAsn1Spec + # + # There're two ways of creating subtypes in ASN.1 what influences + # decoder operation. These methods are: + # 1) Either base types used in or no IMPLICIT tagging has been + # applied on subtyping. + # 2) Subtype syntax drops base type information (by means of + # IMPLICIT tagging. + # The first case allows for complete tag recovery from substrate + # while the second one requires original ASN.1 type spec for + # decoding. + # + # In either case a set of tags (tagSet) is coming from substrate + # in an incremental, tag-by-tag fashion (this is the case of + # EXPLICIT tag which is most basic). Outermost tag comes first + # from the wire. + # + if state == stGetValueDecoderByTag: + if tagSet in self.__tagMap: + concreteDecoder = self.__tagMap[tagSet] + else: + concreteDecoder = None + if concreteDecoder: + state = stDecodeValue + else: + _k = tagSet[:1] + if _k in self.__tagMap: + concreteDecoder = self.__tagMap[_k] + else: + concreteDecoder = None + if concreteDecoder: + state = stDecodeValue + else: + state = stTryAsExplicitTag + if debug.logger and debug.logger & debug.flagDecoder: + debug.logger('codec %s chosen by a built-in type, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "", state == stDecodeValue and 'value' or 'as explicit tag')) + debug.scope.push(concreteDecoder is None and '?' or concreteDecoder.protoComponent.__class__.__name__) + if state == stGetValueDecoderByAsn1Spec: + if isinstance(asn1Spec, (dict, tagmap.TagMap)): + if tagSet in asn1Spec: + __chosenSpec = asn1Spec[tagSet] + else: + __chosenSpec = None + if debug.logger and debug.logger & debug.flagDecoder: + debug.logger('candidate ASN.1 spec is a map of:') + for t, v in asn1Spec.getPosMap().items(): + debug.logger(' %r -> %s' % (t, v.__class__.__name__)) + if asn1Spec.getNegMap(): + debug.logger('but neither of: ') + for i in asn1Spec.getNegMap().items(): + debug.logger(' %r -> %s' % (t, v.__class__.__name__)) + debug.logger('new candidate ASN.1 spec is %s, chosen by %r' % (__chosenSpec is None and '' or __chosenSpec.__class__.__name__, tagSet)) + else: + __chosenSpec = asn1Spec + debug.logger and debug.logger & debug.flagDecoder and debug.logger('candidate ASN.1 spec is %s' % asn1Spec.__class__.__name__) + if __chosenSpec is not None and ( + tagSet == __chosenSpec.getTagSet() or \ + tagSet in __chosenSpec.getTagMap() + ): + # use base type for codec lookup to recover untagged types + baseTagSet = __chosenSpec.baseTagSet + if __chosenSpec.typeId is not None and \ + __chosenSpec.typeId in self.__typeMap: + # ambiguous type + concreteDecoder = self.__typeMap[__chosenSpec.typeId] + debug.logger and debug.logger & debug.flagDecoder and debug.logger('value decoder chosen for an ambiguous type by type ID %s' % (__chosenSpec.typeId,)) + elif baseTagSet in self.__tagMap: + # base type or tagged subtype + concreteDecoder = self.__tagMap[baseTagSet] + debug.logger and debug.logger & debug.flagDecoder and debug.logger('value decoder chosen by base %r' % (baseTagSet,)) + else: + concreteDecoder = None + if concreteDecoder: + asn1Spec = __chosenSpec + state = stDecodeValue + else: + state = stTryAsExplicitTag + elif tagSet == self.__endOfOctetsTagSet: + concreteDecoder = self.__tagMap[tagSet] + state = stDecodeValue + debug.logger and debug.logger & debug.flagDecoder and debug.logger('end-of-octets found') + else: + concreteDecoder = None + state = stTryAsExplicitTag + if debug.logger and debug.logger & debug.flagDecoder: + debug.logger('codec %s chosen by ASN.1 spec, decoding %s' % (state == stDecodeValue and concreteDecoder.__class__.__name__ or "", state == stDecodeValue and 'value' or 'as explicit tag')) + debug.scope.push(__chosenSpec is None and '?' or __chosenSpec.__class__.__name__) + if state == stTryAsExplicitTag: + if tagSet and \ + tagSet[0][1] == tag.tagFormatConstructed and \ + tagSet[0][0] != tag.tagClassUniversal: + # Assume explicit tagging + concreteDecoder = explicitTagDecoder + state = stDecodeValue + else: + concreteDecoder = None + state = self.defaultErrorState + debug.logger and debug.logger & debug.flagDecoder and debug.logger('codec %s chosen, decoding %s' % (concreteDecoder and concreteDecoder.__class__.__name__ or "", state == stDecodeValue and 'value' or 'as failure')) + if state == stDumpRawValue: + concreteDecoder = self.defaultRawDecoder + debug.logger and debug.logger & debug.flagDecoder and debug.logger('codec %s chosen, decoding value' % concreteDecoder.__class__.__name__) + state = stDecodeValue + if state == stDecodeValue: + if recursiveFlag == 0 and not substrateFun: # legacy + substrateFun = lambda a,b,c: (a,b[:c]) + if length == -1: # indef length + value, substrate = concreteDecoder.indefLenValueDecoder( + fullSubstrate, substrate, asn1Spec, tagSet, length, + stGetValueDecoder, self, substrateFun + ) + else: + value, substrate = concreteDecoder.valueDecoder( + fullSubstrate, substrate, asn1Spec, tagSet, length, + stGetValueDecoder, self, substrateFun + ) + state = stStop + debug.logger and debug.logger & debug.flagDecoder and debug.logger('codec %s yields type %s, value:\n%s\n...remaining substrate is: %s' % (concreteDecoder.__class__.__name__, value.__class__.__name__, value.prettyPrint(), substrate and debug.hexdump(substrate) or '')) + if state == stErrorCondition: + raise error.PyAsn1Error( + '%r not in asn1Spec: %r' % (tagSet, asn1Spec) + ) + if debug.logger and debug.logger & debug.flagDecoder: + debug.scope.pop() + debug.logger('decoder left scope %s, call completed' % debug.scope) + return value, substrate + +decode = Decoder(tagMap, typeMap) + +# XXX +# non-recursive decoding; return position rather than substrate diff --git a/src/lib/pyasn1/codec/ber/encoder.py b/src/lib/pyasn1/codec/ber/encoder.py new file mode 100644 index 00000000..173949d0 --- /dev/null +++ b/src/lib/pyasn1/codec/ber/encoder.py @@ -0,0 +1,353 @@ +# BER encoder +from pyasn1.type import base, tag, univ, char, useful +from pyasn1.codec.ber import eoo +from pyasn1.compat.octets import int2oct, oct2int, ints2octs, null, str2octs +from pyasn1 import debug, error + +class Error(Exception): pass + +class AbstractItemEncoder: + supportIndefLenMode = 1 + def encodeTag(self, t, isConstructed): + tagClass, tagFormat, tagId = t.asTuple() # this is a hotspot + v = tagClass | tagFormat + if isConstructed: + v = v|tag.tagFormatConstructed + if tagId < 31: + return int2oct(v|tagId) + else: + s = int2oct(tagId&0x7f) + tagId = tagId >> 7 + while tagId: + s = int2oct(0x80|(tagId&0x7f)) + s + tagId = tagId >> 7 + return int2oct(v|0x1F) + s + + def encodeLength(self, length, defMode): + if not defMode and self.supportIndefLenMode: + return int2oct(0x80) + if length < 0x80: + return int2oct(length) + else: + substrate = null + while length: + substrate = int2oct(length&0xff) + substrate + length = length >> 8 + substrateLen = len(substrate) + if substrateLen > 126: + raise Error('Length octets overflow (%d)' % substrateLen) + return int2oct(0x80 | substrateLen) + substrate + + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + raise Error('Not implemented') + + def _encodeEndOfOctets(self, encodeFun, defMode): + if defMode or not self.supportIndefLenMode: + return null + else: + return encodeFun(eoo.endOfOctets, defMode) + + def encode(self, encodeFun, value, defMode, maxChunkSize): + substrate, isConstructed = self.encodeValue( + encodeFun, value, defMode, maxChunkSize + ) + tagSet = value.getTagSet() + if tagSet: + if not isConstructed: # primitive form implies definite mode + defMode = 1 + return self.encodeTag( + tagSet[-1], isConstructed + ) + self.encodeLength( + len(substrate), defMode + ) + substrate + self._encodeEndOfOctets(encodeFun, defMode) + else: + return substrate # untagged value + +class EndOfOctetsEncoder(AbstractItemEncoder): + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + return null, 0 + +class ExplicitlyTaggedItemEncoder(AbstractItemEncoder): + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + if isinstance(value, base.AbstractConstructedAsn1Item): + value = value.clone(tagSet=value.getTagSet()[:-1], + cloneValueFlag=1) + else: + value = value.clone(tagSet=value.getTagSet()[:-1]) + return encodeFun(value, defMode, maxChunkSize), 1 + +explicitlyTaggedItemEncoder = ExplicitlyTaggedItemEncoder() + +class BooleanEncoder(AbstractItemEncoder): + supportIndefLenMode = 0 + _true = ints2octs((1,)) + _false = ints2octs((0,)) + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + return value and self._true or self._false, 0 + +class IntegerEncoder(AbstractItemEncoder): + supportIndefLenMode = 0 + supportCompactZero = False + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + if value == 0: # shortcut for zero value + if self.supportCompactZero: + # this seems to be a correct way for encoding zeros + return null, 0 + else: + # this seems to be a widespread way for encoding zeros + return ints2octs((0,)), 0 + octets = [] + value = int(value) # to save on ops on asn1 type + while 1: + octets.insert(0, value & 0xff) + if value == 0 or value == -1: + break + value = value >> 8 + if value == 0 and octets[0] & 0x80: + octets.insert(0, 0) + while len(octets) > 1 and \ + (octets[0] == 0 and octets[1] & 0x80 == 0 or \ + octets[0] == 0xff and octets[1] & 0x80 != 0): + del octets[0] + return ints2octs(octets), 0 + +class BitStringEncoder(AbstractItemEncoder): + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + if not maxChunkSize or len(value) <= maxChunkSize*8: + r = {}; l = len(value); p = 0; j = 7 + while p < l: + i, j = divmod(p, 8) + r[i] = r.get(i,0) | value[p]<<(7-j) + p = p + 1 + keys = list(r); keys.sort() + return int2oct(7-j) + ints2octs([r[k] for k in keys]), 0 + else: + pos = 0; substrate = null + while 1: + # count in octets + v = value.clone(value[pos*8:pos*8+maxChunkSize*8]) + if not v: + break + substrate = substrate + encodeFun(v, defMode, maxChunkSize) + pos = pos + maxChunkSize + return substrate, 1 + +class OctetStringEncoder(AbstractItemEncoder): + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + if not maxChunkSize or len(value) <= maxChunkSize: + return value.asOctets(), 0 + else: + pos = 0; substrate = null + while 1: + v = value.clone(value[pos:pos+maxChunkSize]) + if not v: + break + substrate = substrate + encodeFun(v, defMode, maxChunkSize) + pos = pos + maxChunkSize + return substrate, 1 + +class NullEncoder(AbstractItemEncoder): + supportIndefLenMode = 0 + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + return null, 0 + +class ObjectIdentifierEncoder(AbstractItemEncoder): + supportIndefLenMode = 0 + precomputedValues = { + (1, 3, 6, 1, 2): (43, 6, 1, 2), + (1, 3, 6, 1, 4): (43, 6, 1, 4) + } + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + oid = value.asTuple() + if oid[:5] in self.precomputedValues: + octets = self.precomputedValues[oid[:5]] + index = 5 + else: + if len(oid) < 2: + raise error.PyAsn1Error('Short OID %s' % (value,)) + + # Build the first twos + if oid[0] > 6 or oid[1] > 39 or oid[0] == 6 and oid[1] > 15: + raise error.PyAsn1Error( + 'Initial sub-ID overflow %s in OID %s' % (oid[:2], value) + ) + octets = (oid[0] * 40 + oid[1],) + index = 2 + + # Cycle through subids + for subid in oid[index:]: + if subid > -1 and subid < 128: + # Optimize for the common case + octets = octets + (subid & 0x7f,) + elif subid < 0 or subid > 0xFFFFFFFF: + raise error.PyAsn1Error( + 'SubId overflow %s in %s' % (subid, value) + ) + else: + # Pack large Sub-Object IDs + res = (subid & 0x7f,) + subid = subid >> 7 + while subid > 0: + res = (0x80 | (subid & 0x7f),) + res + subid = subid >> 7 + # Add packed Sub-Object ID to resulted Object ID + octets += res + + return ints2octs(octets), 0 + +class RealEncoder(AbstractItemEncoder): + supportIndefLenMode = 0 + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + if value.isPlusInfinity(): + return int2oct(0x40), 0 + if value.isMinusInfinity(): + return int2oct(0x41), 0 + m, b, e = value + if not m: + return null, 0 + if b == 10: + return str2octs('\x03%dE%s%d' % (m, e == 0 and '+' or '', e)), 0 + elif b == 2: + fo = 0x80 # binary enoding + if m < 0: + fo = fo | 0x40 # sign bit + m = -m + while int(m) != m: # drop floating point + m *= 2 + e -= 1 + while m & 0x1 == 0: # mantissa normalization + m >>= 1 + e += 1 + eo = null + while e not in (0, -1): + eo = int2oct(e&0xff) + eo + e >>= 8 + if e == 0 and eo and oct2int(eo[0]) & 0x80: + eo = int2oct(0) + eo + n = len(eo) + if n > 0xff: + raise error.PyAsn1Error('Real exponent overflow') + if n == 1: + pass + elif n == 2: + fo |= 1 + elif n == 3: + fo |= 2 + else: + fo |= 3 + eo = int2oct(n//0xff+1) + eo + po = null + while m: + po = int2oct(m&0xff) + po + m >>= 8 + substrate = int2oct(fo) + eo + po + return substrate, 0 + else: + raise error.PyAsn1Error('Prohibited Real base %s' % b) + +class SequenceEncoder(AbstractItemEncoder): + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + value.setDefaultComponents() + value.verifySizeSpec() + substrate = null; idx = len(value) + while idx > 0: + idx = idx - 1 + if value[idx] is None: # Optional component + continue + component = value.getDefaultComponentByPosition(idx) + if component is not None and component == value[idx]: + continue + substrate = encodeFun( + value[idx], defMode, maxChunkSize + ) + substrate + return substrate, 1 + +class SequenceOfEncoder(AbstractItemEncoder): + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + value.verifySizeSpec() + substrate = null; idx = len(value) + while idx > 0: + idx = idx - 1 + substrate = encodeFun( + value[idx], defMode, maxChunkSize + ) + substrate + return substrate, 1 + +class ChoiceEncoder(AbstractItemEncoder): + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + return encodeFun(value.getComponent(), defMode, maxChunkSize), 1 + +class AnyEncoder(OctetStringEncoder): + def encodeValue(self, encodeFun, value, defMode, maxChunkSize): + return value.asOctets(), defMode == 0 + +tagMap = { + eoo.endOfOctets.tagSet: EndOfOctetsEncoder(), + univ.Boolean.tagSet: BooleanEncoder(), + univ.Integer.tagSet: IntegerEncoder(), + univ.BitString.tagSet: BitStringEncoder(), + univ.OctetString.tagSet: OctetStringEncoder(), + univ.Null.tagSet: NullEncoder(), + univ.ObjectIdentifier.tagSet: ObjectIdentifierEncoder(), + univ.Enumerated.tagSet: IntegerEncoder(), + univ.Real.tagSet: RealEncoder(), + # Sequence & Set have same tags as SequenceOf & SetOf + univ.SequenceOf.tagSet: SequenceOfEncoder(), + univ.SetOf.tagSet: SequenceOfEncoder(), + univ.Choice.tagSet: ChoiceEncoder(), + # character string types + char.UTF8String.tagSet: OctetStringEncoder(), + char.NumericString.tagSet: OctetStringEncoder(), + char.PrintableString.tagSet: OctetStringEncoder(), + char.TeletexString.tagSet: OctetStringEncoder(), + char.VideotexString.tagSet: OctetStringEncoder(), + char.IA5String.tagSet: OctetStringEncoder(), + char.GraphicString.tagSet: OctetStringEncoder(), + char.VisibleString.tagSet: OctetStringEncoder(), + char.GeneralString.tagSet: OctetStringEncoder(), + char.UniversalString.tagSet: OctetStringEncoder(), + char.BMPString.tagSet: OctetStringEncoder(), + # useful types + useful.GeneralizedTime.tagSet: OctetStringEncoder(), + useful.UTCTime.tagSet: OctetStringEncoder() + } + +# Type-to-codec map for ambiguous ASN.1 types +typeMap = { + univ.Set.typeId: SequenceEncoder(), + univ.SetOf.typeId: SequenceOfEncoder(), + univ.Sequence.typeId: SequenceEncoder(), + univ.SequenceOf.typeId: SequenceOfEncoder(), + univ.Choice.typeId: ChoiceEncoder(), + univ.Any.typeId: AnyEncoder() + } + +class Encoder: + def __init__(self, tagMap, typeMap={}): + self.__tagMap = tagMap + self.__typeMap = typeMap + + def __call__(self, value, defMode=1, maxChunkSize=0): + debug.logger & debug.flagEncoder and debug.logger('encoder called in %sdef mode, chunk size %s for type %s, value:\n%s' % (not defMode and 'in' or '', maxChunkSize, value.__class__.__name__, value.prettyPrint())) + tagSet = value.getTagSet() + if len(tagSet) > 1: + concreteEncoder = explicitlyTaggedItemEncoder + else: + if value.typeId is not None and value.typeId in self.__typeMap: + concreteEncoder = self.__typeMap[value.typeId] + elif tagSet in self.__tagMap: + concreteEncoder = self.__tagMap[tagSet] + else: + tagSet = value.baseTagSet + if tagSet in self.__tagMap: + concreteEncoder = self.__tagMap[tagSet] + else: + raise Error('No encoder for %s' % (value,)) + debug.logger & debug.flagEncoder and debug.logger('using value codec %s chosen by %r' % (concreteEncoder.__class__.__name__, tagSet)) + substrate = concreteEncoder.encode( + self, value, defMode, maxChunkSize + ) + debug.logger & debug.flagEncoder and debug.logger('built %s octets of substrate: %s\nencoder completed' % (len(substrate), debug.hexdump(substrate))) + return substrate + +encode = Encoder(tagMap, typeMap) diff --git a/src/lib/pyasn1/codec/ber/eoo.py b/src/lib/pyasn1/codec/ber/eoo.py new file mode 100644 index 00000000..379be199 --- /dev/null +++ b/src/lib/pyasn1/codec/ber/eoo.py @@ -0,0 +1,8 @@ +from pyasn1.type import base, tag + +class EndOfOctets(base.AbstractSimpleAsn1Item): + defaultValue = 0 + tagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x00) + ) +endOfOctets = EndOfOctets() diff --git a/src/lib/pyasn1/codec/cer/__init__.py b/src/lib/pyasn1/codec/cer/__init__.py new file mode 100644 index 00000000..8c3066b2 --- /dev/null +++ b/src/lib/pyasn1/codec/cer/__init__.py @@ -0,0 +1 @@ +# This file is necessary to make this directory a package. diff --git a/src/lib/pyasn1/codec/cer/decoder.py b/src/lib/pyasn1/codec/cer/decoder.py new file mode 100644 index 00000000..9fd37c13 --- /dev/null +++ b/src/lib/pyasn1/codec/cer/decoder.py @@ -0,0 +1,35 @@ +# CER decoder +from pyasn1.type import univ +from pyasn1.codec.ber import decoder +from pyasn1.compat.octets import oct2int +from pyasn1 import error + +class BooleanDecoder(decoder.AbstractSimpleDecoder): + protoComponent = univ.Boolean(0) + def valueDecoder(self, fullSubstrate, substrate, asn1Spec, tagSet, length, + state, decodeFun, substrateFun): + head, tail = substrate[:length], substrate[length:] + if not head: + raise error.PyAsn1Error('Empty substrate') + byte = oct2int(head[0]) + # CER/DER specifies encoding of TRUE as 0xFF and FALSE as 0x0, while + # BER allows any non-zero value as TRUE; cf. sections 8.2.2. and 11.1 + # in http://www.itu.int/ITU-T/studygroups/com17/languages/X.690-0207.pdf + if byte == 0xff: + value = 1 + elif byte == 0x00: + value = 0 + else: + raise error.PyAsn1Error('Boolean CER violation: %s' % byte) + return self._createComponent(asn1Spec, tagSet, value), tail + +tagMap = decoder.tagMap.copy() +tagMap.update({ + univ.Boolean.tagSet: BooleanDecoder() + }) + +typeMap = decoder.typeMap + +class Decoder(decoder.Decoder): pass + +decode = Decoder(tagMap, decoder.typeMap) diff --git a/src/lib/pyasn1/codec/cer/encoder.py b/src/lib/pyasn1/codec/cer/encoder.py new file mode 100644 index 00000000..4c05130a --- /dev/null +++ b/src/lib/pyasn1/codec/cer/encoder.py @@ -0,0 +1,87 @@ +# CER encoder +from pyasn1.type import univ +from pyasn1.codec.ber import encoder +from pyasn1.compat.octets import int2oct, null + +class BooleanEncoder(encoder.IntegerEncoder): + def encodeValue(self, encodeFun, client, defMode, maxChunkSize): + if client == 0: + substrate = int2oct(0) + else: + substrate = int2oct(255) + return substrate, 0 + +class BitStringEncoder(encoder.BitStringEncoder): + def encodeValue(self, encodeFun, client, defMode, maxChunkSize): + return encoder.BitStringEncoder.encodeValue( + self, encodeFun, client, defMode, 1000 + ) + +class OctetStringEncoder(encoder.OctetStringEncoder): + def encodeValue(self, encodeFun, client, defMode, maxChunkSize): + return encoder.OctetStringEncoder.encodeValue( + self, encodeFun, client, defMode, 1000 + ) + +# specialized RealEncoder here +# specialized GeneralStringEncoder here +# specialized GeneralizedTimeEncoder here +# specialized UTCTimeEncoder here + +class SetOfEncoder(encoder.SequenceOfEncoder): + def encodeValue(self, encodeFun, client, defMode, maxChunkSize): + if isinstance(client, univ.SequenceAndSetBase): + client.setDefaultComponents() + client.verifySizeSpec() + substrate = null; idx = len(client) + # This is certainly a hack but how else do I distinguish SetOf + # from Set if they have the same tags&constraints? + if isinstance(client, univ.SequenceAndSetBase): + # Set + comps = [] + while idx > 0: + idx = idx - 1 + if client[idx] is None: # Optional component + continue + if client.getDefaultComponentByPosition(idx) == client[idx]: + continue + comps.append(client[idx]) + comps.sort(key=lambda x: isinstance(x, univ.Choice) and \ + x.getMinTagSet() or x.getTagSet()) + for c in comps: + substrate += encodeFun(c, defMode, maxChunkSize) + else: + # SetOf + compSubs = [] + while idx > 0: + idx = idx - 1 + compSubs.append( + encodeFun(client[idx], defMode, maxChunkSize) + ) + compSubs.sort() # perhaps padding's not needed + substrate = null + for compSub in compSubs: + substrate += compSub + return substrate, 1 + +tagMap = encoder.tagMap.copy() +tagMap.update({ + univ.Boolean.tagSet: BooleanEncoder(), + univ.BitString.tagSet: BitStringEncoder(), + univ.OctetString.tagSet: OctetStringEncoder(), + univ.SetOf().tagSet: SetOfEncoder() # conflcts with Set + }) + +typeMap = encoder.typeMap.copy() +typeMap.update({ + univ.Set.typeId: SetOfEncoder(), + univ.SetOf.typeId: SetOfEncoder() + }) + +class Encoder(encoder.Encoder): + def __call__(self, client, defMode=0, maxChunkSize=0): + return encoder.Encoder.__call__(self, client, defMode, maxChunkSize) + +encode = Encoder(tagMap, typeMap) + +# EncoderFactory queries class instance and builds a map of tags -> encoders diff --git a/src/lib/pyasn1/codec/der/__init__.py b/src/lib/pyasn1/codec/der/__init__.py new file mode 100644 index 00000000..8c3066b2 --- /dev/null +++ b/src/lib/pyasn1/codec/der/__init__.py @@ -0,0 +1 @@ +# This file is necessary to make this directory a package. diff --git a/src/lib/pyasn1/codec/der/decoder.py b/src/lib/pyasn1/codec/der/decoder.py new file mode 100644 index 00000000..604abec2 --- /dev/null +++ b/src/lib/pyasn1/codec/der/decoder.py @@ -0,0 +1,9 @@ +# DER decoder +from pyasn1.type import univ +from pyasn1.codec.cer import decoder + +tagMap = decoder.tagMap +typeMap = decoder.typeMap +Decoder = decoder.Decoder + +decode = Decoder(tagMap, typeMap) diff --git a/src/lib/pyasn1/codec/der/encoder.py b/src/lib/pyasn1/codec/der/encoder.py new file mode 100644 index 00000000..4e5faefa --- /dev/null +++ b/src/lib/pyasn1/codec/der/encoder.py @@ -0,0 +1,28 @@ +# DER encoder +from pyasn1.type import univ +from pyasn1.codec.cer import encoder + +class SetOfEncoder(encoder.SetOfEncoder): + def _cmpSetComponents(self, c1, c2): + tagSet1 = isinstance(c1, univ.Choice) and \ + c1.getEffectiveTagSet() or c1.getTagSet() + tagSet2 = isinstance(c2, univ.Choice) and \ + c2.getEffectiveTagSet() or c2.getTagSet() + return cmp(tagSet1, tagSet2) + +tagMap = encoder.tagMap.copy() +tagMap.update({ + # Overload CER encodrs with BER ones (a bit hackerish XXX) + univ.BitString.tagSet: encoder.encoder.BitStringEncoder(), + univ.OctetString.tagSet: encoder.encoder.OctetStringEncoder(), + # Set & SetOf have same tags + univ.SetOf().tagSet: SetOfEncoder() + }) + +typeMap = encoder.typeMap + +class Encoder(encoder.Encoder): + def __call__(self, client, defMode=1, maxChunkSize=0): + return encoder.Encoder.__call__(self, client, defMode, maxChunkSize) + +encode = Encoder(tagMap, typeMap) diff --git a/src/lib/pyasn1/compat/__init__.py b/src/lib/pyasn1/compat/__init__.py new file mode 100644 index 00000000..8c3066b2 --- /dev/null +++ b/src/lib/pyasn1/compat/__init__.py @@ -0,0 +1 @@ +# This file is necessary to make this directory a package. diff --git a/src/lib/pyasn1/compat/octets.py b/src/lib/pyasn1/compat/octets.py new file mode 100644 index 00000000..f7f2a29b --- /dev/null +++ b/src/lib/pyasn1/compat/octets.py @@ -0,0 +1,20 @@ +from sys import version_info + +if version_info[0] <= 2: + int2oct = chr + ints2octs = lambda s: ''.join([ int2oct(x) for x in s ]) + null = '' + oct2int = ord + octs2ints = lambda s: [ oct2int(x) for x in s ] + str2octs = lambda x: x + octs2str = lambda x: x + isOctetsType = lambda s: isinstance(s, str) +else: + ints2octs = bytes + int2oct = lambda x: ints2octs((x,)) + null = ints2octs() + oct2int = lambda x: x + octs2ints = lambda s: [ x for x in s ] + str2octs = lambda x: x.encode() + octs2str = lambda x: x.decode() + isOctetsType = lambda s: isinstance(s, bytes) diff --git a/src/lib/pyasn1/debug.py b/src/lib/pyasn1/debug.py new file mode 100644 index 00000000..c27cb1d4 --- /dev/null +++ b/src/lib/pyasn1/debug.py @@ -0,0 +1,65 @@ +import sys +from pyasn1.compat.octets import octs2ints +from pyasn1 import error +from pyasn1 import __version__ + +flagNone = 0x0000 +flagEncoder = 0x0001 +flagDecoder = 0x0002 +flagAll = 0xffff + +flagMap = { + 'encoder': flagEncoder, + 'decoder': flagDecoder, + 'all': flagAll + } + +class Debug: + defaultPrinter = sys.stderr.write + def __init__(self, *flags): + self._flags = flagNone + self._printer = self.defaultPrinter + self('running pyasn1 version %s' % __version__) + for f in flags: + if f not in flagMap: + raise error.PyAsn1Error('bad debug flag %s' % (f,)) + self._flags = self._flags | flagMap[f] + self('debug category \'%s\' enabled' % f) + + def __str__(self): + return 'logger %s, flags %x' % (self._printer, self._flags) + + def __call__(self, msg): + self._printer('DBG: %s\n' % msg) + + def __and__(self, flag): + return self._flags & flag + + def __rand__(self, flag): + return flag & self._flags + +logger = 0 + +def setLogger(l): + global logger + logger = l + +def hexdump(octets): + return ' '.join( + [ '%s%.2X' % (n%16 == 0 and ('\n%.5d: ' % n) or '', x) + for n,x in zip(range(len(octets)), octs2ints(octets)) ] + ) + +class Scope: + def __init__(self): + self._list = [] + + def __str__(self): return '.'.join(self._list) + + def push(self, token): + self._list.append(token) + + def pop(self): + return self._list.pop() + +scope = Scope() diff --git a/src/lib/pyasn1/error.py b/src/lib/pyasn1/error.py new file mode 100644 index 00000000..716406ff --- /dev/null +++ b/src/lib/pyasn1/error.py @@ -0,0 +1,3 @@ +class PyAsn1Error(Exception): pass +class ValueConstraintError(PyAsn1Error): pass +class SubstrateUnderrunError(PyAsn1Error): pass diff --git a/src/lib/pyasn1/type/__init__.py b/src/lib/pyasn1/type/__init__.py new file mode 100644 index 00000000..8c3066b2 --- /dev/null +++ b/src/lib/pyasn1/type/__init__.py @@ -0,0 +1 @@ +# This file is necessary to make this directory a package. diff --git a/src/lib/pyasn1/type/base.py b/src/lib/pyasn1/type/base.py new file mode 100644 index 00000000..40873719 --- /dev/null +++ b/src/lib/pyasn1/type/base.py @@ -0,0 +1,249 @@ +# Base classes for ASN.1 types +import sys +from pyasn1.type import constraint, tagmap +from pyasn1 import error + +class Asn1Item: pass + +class Asn1ItemBase(Asn1Item): + # Set of tags for this ASN.1 type + tagSet = () + + # A list of constraint.Constraint instances for checking values + subtypeSpec = constraint.ConstraintsIntersection() + + # Used for ambiguous ASN.1 types identification + typeId = None + + def __init__(self, tagSet=None, subtypeSpec=None): + if tagSet is None: + self._tagSet = self.tagSet + else: + self._tagSet = tagSet + if subtypeSpec is None: + self._subtypeSpec = self.subtypeSpec + else: + self._subtypeSpec = subtypeSpec + + def _verifySubtypeSpec(self, value, idx=None): + try: + self._subtypeSpec(value, idx) + except error.PyAsn1Error: + c, i, t = sys.exc_info() + raise c('%s at %s' % (i, self.__class__.__name__)) + + def getSubtypeSpec(self): return self._subtypeSpec + + def getTagSet(self): return self._tagSet + def getEffectiveTagSet(self): return self._tagSet # used by untagged types + def getTagMap(self): return tagmap.TagMap({self._tagSet: self}) + + def isSameTypeWith(self, other): + return self is other or \ + self._tagSet == other.getTagSet() and \ + self._subtypeSpec == other.getSubtypeSpec() + def isSuperTypeOf(self, other): + """Returns true if argument is a ASN1 subtype of ourselves""" + return self._tagSet.isSuperTagSetOf(other.getTagSet()) and \ + self._subtypeSpec.isSuperTypeOf(other.getSubtypeSpec()) + +class __NoValue: + def __getattr__(self, attr): + raise error.PyAsn1Error('No value for %s()' % attr) + def __getitem__(self, i): + raise error.PyAsn1Error('No value') + +noValue = __NoValue() + +# Base class for "simple" ASN.1 objects. These are immutable. +class AbstractSimpleAsn1Item(Asn1ItemBase): + defaultValue = noValue + def __init__(self, value=None, tagSet=None, subtypeSpec=None): + Asn1ItemBase.__init__(self, tagSet, subtypeSpec) + if value is None or value is noValue: + value = self.defaultValue + if value is None or value is noValue: + self.__hashedValue = value = noValue + else: + value = self.prettyIn(value) + self._verifySubtypeSpec(value) + self.__hashedValue = hash(value) + self._value = value + self._len = None + + def __repr__(self): + if self._value is noValue: + return self.__class__.__name__ + '()' + else: + return self.__class__.__name__ + '(%s)' % (self.prettyOut(self._value),) + def __str__(self): return str(self._value) + def __eq__(self, other): + return self is other and True or self._value == other + def __ne__(self, other): return self._value != other + def __lt__(self, other): return self._value < other + def __le__(self, other): return self._value <= other + def __gt__(self, other): return self._value > other + def __ge__(self, other): return self._value >= other + if sys.version_info[0] <= 2: + def __nonzero__(self): return bool(self._value) + else: + def __bool__(self): return bool(self._value) + def __hash__(self): return self.__hashedValue + + def clone(self, value=None, tagSet=None, subtypeSpec=None): + if value is None and tagSet is None and subtypeSpec is None: + return self + if value is None: + value = self._value + if tagSet is None: + tagSet = self._tagSet + if subtypeSpec is None: + subtypeSpec = self._subtypeSpec + return self.__class__(value, tagSet, subtypeSpec) + + def subtype(self, value=None, implicitTag=None, explicitTag=None, + subtypeSpec=None): + if value is None: + value = self._value + if implicitTag is not None: + tagSet = self._tagSet.tagImplicitly(implicitTag) + elif explicitTag is not None: + tagSet = self._tagSet.tagExplicitly(explicitTag) + else: + tagSet = self._tagSet + if subtypeSpec is None: + subtypeSpec = self._subtypeSpec + else: + subtypeSpec = subtypeSpec + self._subtypeSpec + return self.__class__(value, tagSet, subtypeSpec) + + def prettyIn(self, value): return value + def prettyOut(self, value): return str(value) + + def prettyPrint(self, scope=0): + if self._value is noValue: + return '' + else: + return self.prettyOut(self._value) + + # XXX Compatibility stub + def prettyPrinter(self, scope=0): return self.prettyPrint(scope) + +# +# Constructed types: +# * There are five of them: Sequence, SequenceOf/SetOf, Set and Choice +# * ASN1 types and values are represened by Python class instances +# * Value initialization is made for defaulted components only +# * Primary method of component addressing is by-position. Data model for base +# type is Python sequence. Additional type-specific addressing methods +# may be implemented for particular types. +# * SequenceOf and SetOf types do not implement any additional methods +# * Sequence, Set and Choice types also implement by-identifier addressing +# * Sequence, Set and Choice types also implement by-asn1-type (tag) addressing +# * Sequence and Set types may include optional and defaulted +# components +# * Constructed types hold a reference to component types used for value +# verification and ordering. +# * Component type is a scalar type for SequenceOf/SetOf types and a list +# of types for Sequence/Set/Choice. +# + +class AbstractConstructedAsn1Item(Asn1ItemBase): + componentType = None + sizeSpec = constraint.ConstraintsIntersection() + def __init__(self, componentType=None, tagSet=None, + subtypeSpec=None, sizeSpec=None): + Asn1ItemBase.__init__(self, tagSet, subtypeSpec) + if componentType is None: + self._componentType = self.componentType + else: + self._componentType = componentType + if sizeSpec is None: + self._sizeSpec = self.sizeSpec + else: + self._sizeSpec = sizeSpec + self._componentValues = [] + self._componentValuesSet = 0 + + def __repr__(self): + r = self.__class__.__name__ + '()' + for idx in range(len(self._componentValues)): + if self._componentValues[idx] is None: + continue + r = r + '.setComponentByPosition(%s, %r)' % ( + idx, self._componentValues[idx] + ) + return r + + def __eq__(self, other): + return self is other and True or self._componentValues == other + def __ne__(self, other): return self._componentValues != other + def __lt__(self, other): return self._componentValues < other + def __le__(self, other): return self._componentValues <= other + def __gt__(self, other): return self._componentValues > other + def __ge__(self, other): return self._componentValues >= other + if sys.version_info[0] <= 2: + def __nonzero__(self): return bool(self._componentValues) + else: + def __bool__(self): return bool(self._componentValues) + + def getComponentTagMap(self): + raise error.PyAsn1Error('Method not implemented') + + def _cloneComponentValues(self, myClone, cloneValueFlag): pass + + def clone(self, tagSet=None, subtypeSpec=None, sizeSpec=None, + cloneValueFlag=None): + if tagSet is None: + tagSet = self._tagSet + if subtypeSpec is None: + subtypeSpec = self._subtypeSpec + if sizeSpec is None: + sizeSpec = self._sizeSpec + r = self.__class__(self._componentType, tagSet, subtypeSpec, sizeSpec) + if cloneValueFlag: + self._cloneComponentValues(r, cloneValueFlag) + return r + + def subtype(self, implicitTag=None, explicitTag=None, subtypeSpec=None, + sizeSpec=None, cloneValueFlag=None): + if implicitTag is not None: + tagSet = self._tagSet.tagImplicitly(implicitTag) + elif explicitTag is not None: + tagSet = self._tagSet.tagExplicitly(explicitTag) + else: + tagSet = self._tagSet + if subtypeSpec is None: + subtypeSpec = self._subtypeSpec + else: + subtypeSpec = subtypeSpec + self._subtypeSpec + if sizeSpec is None: + sizeSpec = self._sizeSpec + else: + sizeSpec = sizeSpec + self._sizeSpec + r = self.__class__(self._componentType, tagSet, subtypeSpec, sizeSpec) + if cloneValueFlag: + self._cloneComponentValues(r, cloneValueFlag) + return r + + def _verifyComponent(self, idx, value): pass + + def verifySizeSpec(self): self._sizeSpec(self) + + def getComponentByPosition(self, idx): + raise error.PyAsn1Error('Method not implemented') + def setComponentByPosition(self, idx, value, verifyConstraints=True): + raise error.PyAsn1Error('Method not implemented') + + def getComponentType(self): return self._componentType + + def __getitem__(self, idx): return self.getComponentByPosition(idx) + def __setitem__(self, idx, value): self.setComponentByPosition(idx, value) + + def __len__(self): return len(self._componentValues) + + def clear(self): + self._componentValues = [] + self._componentValuesSet = 0 + + def setDefaultComponents(self): pass diff --git a/src/lib/pyasn1/type/char.py b/src/lib/pyasn1/type/char.py new file mode 100644 index 00000000..ae112f8b --- /dev/null +++ b/src/lib/pyasn1/type/char.py @@ -0,0 +1,61 @@ +# ASN.1 "character string" types +from pyasn1.type import univ, tag + +class UTF8String(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 12) + ) + encoding = "utf-8" + +class NumericString(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 18) + ) + +class PrintableString(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 19) + ) + +class TeletexString(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 20) + ) + + +class VideotexString(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 21) + ) + +class IA5String(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 22) + ) + +class GraphicString(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 25) + ) + +class VisibleString(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 26) + ) + +class GeneralString(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 27) + ) + +class UniversalString(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 28) + ) + encoding = "utf-32-be" + +class BMPString(univ.OctetString): + tagSet = univ.OctetString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 30) + ) + encoding = "utf-16-be" diff --git a/src/lib/pyasn1/type/constraint.py b/src/lib/pyasn1/type/constraint.py new file mode 100644 index 00000000..66873937 --- /dev/null +++ b/src/lib/pyasn1/type/constraint.py @@ -0,0 +1,200 @@ +# +# ASN.1 subtype constraints classes. +# +# Constraints are relatively rare, but every ASN1 object +# is doing checks all the time for whether they have any +# constraints and whether they are applicable to the object. +# +# What we're going to do is define objects/functions that +# can be called unconditionally if they are present, and that +# are simply not present if there are no constraints. +# +# Original concept and code by Mike C. Fletcher. +# +import sys +from pyasn1.type import error + +class AbstractConstraint: + """Abstract base-class for constraint objects + + Constraints should be stored in a simple sequence in the + namespace of their client Asn1Item sub-classes. + """ + def __init__(self, *values): + self._valueMap = {} + self._setValues(values) + self.__hashedValues = None + def __call__(self, value, idx=None): + try: + self._testValue(value, idx) + except error.ValueConstraintError: + raise error.ValueConstraintError( + '%s failed at: \"%s\"' % (self, sys.exc_info()[1]) + ) + def __repr__(self): + return '%s(%s)' % ( + self.__class__.__name__, + ', '.join([repr(x) for x in self._values]) + ) + def __eq__(self, other): + return self is other and True or self._values == other + def __ne__(self, other): return self._values != other + def __lt__(self, other): return self._values < other + def __le__(self, other): return self._values <= other + def __gt__(self, other): return self._values > other + def __ge__(self, other): return self._values >= other + if sys.version_info[0] <= 2: + def __nonzero__(self): return bool(self._values) + else: + def __bool__(self): return bool(self._values) + + def __hash__(self): + if self.__hashedValues is None: + self.__hashedValues = hash((self.__class__.__name__, self._values)) + return self.__hashedValues + + def _setValues(self, values): self._values = values + def _testValue(self, value, idx): + raise error.ValueConstraintError(value) + + # Constraints derivation logic + def getValueMap(self): return self._valueMap + def isSuperTypeOf(self, otherConstraint): + return self in otherConstraint.getValueMap() or \ + otherConstraint is self or otherConstraint == self + def isSubTypeOf(self, otherConstraint): + return otherConstraint in self._valueMap or \ + otherConstraint is self or otherConstraint == self + +class SingleValueConstraint(AbstractConstraint): + """Value must be part of defined values constraint""" + def _testValue(self, value, idx): + # XXX index vals for performance? + if value not in self._values: + raise error.ValueConstraintError(value) + +class ContainedSubtypeConstraint(AbstractConstraint): + """Value must satisfy all of defined set of constraints""" + def _testValue(self, value, idx): + for c in self._values: + c(value, idx) + +class ValueRangeConstraint(AbstractConstraint): + """Value must be within start and stop values (inclusive)""" + def _testValue(self, value, idx): + if value < self.start or value > self.stop: + raise error.ValueConstraintError(value) + + def _setValues(self, values): + if len(values) != 2: + raise error.PyAsn1Error( + '%s: bad constraint values' % (self.__class__.__name__,) + ) + self.start, self.stop = values + if self.start > self.stop: + raise error.PyAsn1Error( + '%s: screwed constraint values (start > stop): %s > %s' % ( + self.__class__.__name__, + self.start, self.stop + ) + ) + AbstractConstraint._setValues(self, values) + +class ValueSizeConstraint(ValueRangeConstraint): + """len(value) must be within start and stop values (inclusive)""" + def _testValue(self, value, idx): + l = len(value) + if l < self.start or l > self.stop: + raise error.ValueConstraintError(value) + +class PermittedAlphabetConstraint(SingleValueConstraint): + def _setValues(self, values): + self._values = () + for v in values: + self._values = self._values + tuple(v) + + def _testValue(self, value, idx): + for v in value: + if v not in self._values: + raise error.ValueConstraintError(value) + +# This is a bit kludgy, meaning two op modes within a single constraing +class InnerTypeConstraint(AbstractConstraint): + """Value must satisfy type and presense constraints""" + def _testValue(self, value, idx): + if self.__singleTypeConstraint: + self.__singleTypeConstraint(value) + elif self.__multipleTypeConstraint: + if idx not in self.__multipleTypeConstraint: + raise error.ValueConstraintError(value) + constraint, status = self.__multipleTypeConstraint[idx] + if status == 'ABSENT': # XXX presense is not checked! + raise error.ValueConstraintError(value) + constraint(value) + + def _setValues(self, values): + self.__multipleTypeConstraint = {} + self.__singleTypeConstraint = None + for v in values: + if isinstance(v, tuple): + self.__multipleTypeConstraint[v[0]] = v[1], v[2] + else: + self.__singleTypeConstraint = v + AbstractConstraint._setValues(self, values) + +# Boolean ops on constraints + +class ConstraintsExclusion(AbstractConstraint): + """Value must not fit the single constraint""" + def _testValue(self, value, idx): + try: + self._values[0](value, idx) + except error.ValueConstraintError: + return + else: + raise error.ValueConstraintError(value) + + def _setValues(self, values): + if len(values) != 1: + raise error.PyAsn1Error('Single constraint expected') + AbstractConstraint._setValues(self, values) + +class AbstractConstraintSet(AbstractConstraint): + """Value must not satisfy the single constraint""" + def __getitem__(self, idx): return self._values[idx] + + def __add__(self, value): return self.__class__(self, value) + def __radd__(self, value): return self.__class__(self, value) + + def __len__(self): return len(self._values) + + # Constraints inclusion in sets + + def _setValues(self, values): + self._values = values + for v in values: + self._valueMap[v] = 1 + self._valueMap.update(v.getValueMap()) + +class ConstraintsIntersection(AbstractConstraintSet): + """Value must satisfy all constraints""" + def _testValue(self, value, idx): + for v in self._values: + v(value, idx) + +class ConstraintsUnion(AbstractConstraintSet): + """Value must satisfy at least one constraint""" + def _testValue(self, value, idx): + for v in self._values: + try: + v(value, idx) + except error.ValueConstraintError: + pass + else: + return + raise error.ValueConstraintError( + 'all of %s failed for \"%s\"' % (self._values, value) + ) + +# XXX +# add tests for type check diff --git a/src/lib/pyasn1/type/error.py b/src/lib/pyasn1/type/error.py new file mode 100644 index 00000000..3e684844 --- /dev/null +++ b/src/lib/pyasn1/type/error.py @@ -0,0 +1,3 @@ +from pyasn1.error import PyAsn1Error + +class ValueConstraintError(PyAsn1Error): pass diff --git a/src/lib/pyasn1/type/namedtype.py b/src/lib/pyasn1/type/namedtype.py new file mode 100644 index 00000000..48967a5f --- /dev/null +++ b/src/lib/pyasn1/type/namedtype.py @@ -0,0 +1,132 @@ +# NamedType specification for constructed types +import sys +from pyasn1.type import tagmap +from pyasn1 import error + +class NamedType: + isOptional = 0 + isDefaulted = 0 + def __init__(self, name, t): + self.__name = name; self.__type = t + def __repr__(self): return '%s(%s, %s)' % ( + self.__class__.__name__, self.__name, self.__type + ) + def getType(self): return self.__type + def getName(self): return self.__name + def __getitem__(self, idx): + if idx == 0: return self.__name + if idx == 1: return self.__type + raise IndexError() + +class OptionalNamedType(NamedType): + isOptional = 1 +class DefaultedNamedType(NamedType): + isDefaulted = 1 + +class NamedTypes: + def __init__(self, *namedTypes): + self.__namedTypes = namedTypes + self.__namedTypesLen = len(self.__namedTypes) + self.__minTagSet = None + self.__tagToPosIdx = {}; self.__nameToPosIdx = {} + self.__tagMap = { False: None, True: None } + self.__ambigiousTypes = {} + + def __repr__(self): + r = '%s(' % self.__class__.__name__ + for n in self.__namedTypes: + r = r + '%r, ' % (n,) + return r + ')' + + def __getitem__(self, idx): return self.__namedTypes[idx] + + if sys.version_info[0] <= 2: + def __nonzero__(self): return bool(self.__namedTypesLen) + else: + def __bool__(self): return bool(self.__namedTypesLen) + def __len__(self): return self.__namedTypesLen + + def getTypeByPosition(self, idx): + if idx < 0 or idx >= self.__namedTypesLen: + raise error.PyAsn1Error('Type position out of range') + else: + return self.__namedTypes[idx].getType() + + def getPositionByType(self, tagSet): + if not self.__tagToPosIdx: + idx = self.__namedTypesLen + while idx > 0: + idx = idx - 1 + tagMap = self.__namedTypes[idx].getType().getTagMap() + for t in tagMap.getPosMap(): + if t in self.__tagToPosIdx: + raise error.PyAsn1Error('Duplicate type %s' % (t,)) + self.__tagToPosIdx[t] = idx + try: + return self.__tagToPosIdx[tagSet] + except KeyError: + raise error.PyAsn1Error('Type %s not found' % (tagSet,)) + + def getNameByPosition(self, idx): + try: + return self.__namedTypes[idx].getName() + except IndexError: + raise error.PyAsn1Error('Type position out of range') + def getPositionByName(self, name): + if not self.__nameToPosIdx: + idx = self.__namedTypesLen + while idx > 0: + idx = idx - 1 + n = self.__namedTypes[idx].getName() + if n in self.__nameToPosIdx: + raise error.PyAsn1Error('Duplicate name %s' % (n,)) + self.__nameToPosIdx[n] = idx + try: + return self.__nameToPosIdx[name] + except KeyError: + raise error.PyAsn1Error('Name %s not found' % (name,)) + + def __buildAmbigiousTagMap(self): + ambigiousTypes = () + idx = self.__namedTypesLen + while idx > 0: + idx = idx - 1 + t = self.__namedTypes[idx] + if t.isOptional or t.isDefaulted: + ambigiousTypes = (t, ) + ambigiousTypes + else: + ambigiousTypes = (t, ) + self.__ambigiousTypes[idx] = NamedTypes(*ambigiousTypes) + + def getTagMapNearPosition(self, idx): + if not self.__ambigiousTypes: self.__buildAmbigiousTagMap() + try: + return self.__ambigiousTypes[idx].getTagMap() + except KeyError: + raise error.PyAsn1Error('Type position out of range') + + def getPositionNearType(self, tagSet, idx): + if not self.__ambigiousTypes: self.__buildAmbigiousTagMap() + try: + return idx+self.__ambigiousTypes[idx].getPositionByType(tagSet) + except KeyError: + raise error.PyAsn1Error('Type position out of range') + + def genMinTagSet(self): + if self.__minTagSet is None: + for t in self.__namedTypes: + __type = t.getType() + tagSet = getattr(__type,'getMinTagSet',__type.getTagSet)() + if self.__minTagSet is None or tagSet < self.__minTagSet: + self.__minTagSet = tagSet + return self.__minTagSet + + def getTagMap(self, uniq=False): + if self.__tagMap[uniq] is None: + tagMap = tagmap.TagMap() + for nt in self.__namedTypes: + tagMap = tagMap.clone( + nt.getType(), nt.getType().getTagMap(), uniq + ) + self.__tagMap[uniq] = tagMap + return self.__tagMap[uniq] diff --git a/src/lib/pyasn1/type/namedval.py b/src/lib/pyasn1/type/namedval.py new file mode 100644 index 00000000..d0fea7cc --- /dev/null +++ b/src/lib/pyasn1/type/namedval.py @@ -0,0 +1,46 @@ +# ASN.1 named integers +from pyasn1 import error + +__all__ = [ 'NamedValues' ] + +class NamedValues: + def __init__(self, *namedValues): + self.nameToValIdx = {}; self.valToNameIdx = {} + self.namedValues = () + automaticVal = 1 + for namedValue in namedValues: + if isinstance(namedValue, tuple): + name, val = namedValue + else: + name = namedValue + val = automaticVal + if name in self.nameToValIdx: + raise error.PyAsn1Error('Duplicate name %s' % (name,)) + self.nameToValIdx[name] = val + if val in self.valToNameIdx: + raise error.PyAsn1Error('Duplicate value %s=%s' % (name, val)) + self.valToNameIdx[val] = name + self.namedValues = self.namedValues + ((name, val),) + automaticVal = automaticVal + 1 + def __str__(self): return str(self.namedValues) + + def getName(self, value): + if value in self.valToNameIdx: + return self.valToNameIdx[value] + + def getValue(self, name): + if name in self.nameToValIdx: + return self.nameToValIdx[name] + + def __getitem__(self, i): return self.namedValues[i] + def __len__(self): return len(self.namedValues) + + def __add__(self, namedValues): + return self.__class__(*self.namedValues + namedValues) + def __radd__(self, namedValues): + return self.__class__(*namedValues + tuple(self)) + + def clone(self, *namedValues): + return self.__class__(*tuple(self) + namedValues) + +# XXX clone/subtype? diff --git a/src/lib/pyasn1/type/tag.py b/src/lib/pyasn1/type/tag.py new file mode 100644 index 00000000..1144907f --- /dev/null +++ b/src/lib/pyasn1/type/tag.py @@ -0,0 +1,122 @@ +# ASN.1 types tags +from operator import getitem +from pyasn1 import error + +tagClassUniversal = 0x00 +tagClassApplication = 0x40 +tagClassContext = 0x80 +tagClassPrivate = 0xC0 + +tagFormatSimple = 0x00 +tagFormatConstructed = 0x20 + +tagCategoryImplicit = 0x01 +tagCategoryExplicit = 0x02 +tagCategoryUntagged = 0x04 + +class Tag: + def __init__(self, tagClass, tagFormat, tagId): + if tagId < 0: + raise error.PyAsn1Error( + 'Negative tag ID (%s) not allowed' % (tagId,) + ) + self.__tag = (tagClass, tagFormat, tagId) + self.uniq = (tagClass, tagId) + self.__hashedUniqTag = hash(self.uniq) + + def __repr__(self): + return '%s(tagClass=%s, tagFormat=%s, tagId=%s)' % ( + (self.__class__.__name__,) + self.__tag + ) + # These is really a hotspot -- expose public "uniq" attribute to save on + # function calls + def __eq__(self, other): return self.uniq == other.uniq + def __ne__(self, other): return self.uniq != other.uniq + def __lt__(self, other): return self.uniq < other.uniq + def __le__(self, other): return self.uniq <= other.uniq + def __gt__(self, other): return self.uniq > other.uniq + def __ge__(self, other): return self.uniq >= other.uniq + def __hash__(self): return self.__hashedUniqTag + def __getitem__(self, idx): return self.__tag[idx] + def __and__(self, otherTag): + (tagClass, tagFormat, tagId) = otherTag + return self.__class__( + self.__tag&tagClass, self.__tag&tagFormat, self.__tag&tagId + ) + def __or__(self, otherTag): + (tagClass, tagFormat, tagId) = otherTag + return self.__class__( + self.__tag[0]|tagClass, + self.__tag[1]|tagFormat, + self.__tag[2]|tagId + ) + def asTuple(self): return self.__tag # __getitem__() is slow + +class TagSet: + def __init__(self, baseTag=(), *superTags): + self.__baseTag = baseTag + self.__superTags = superTags + self.__hashedSuperTags = hash(superTags) + _uniq = () + for t in superTags: + _uniq = _uniq + t.uniq + self.uniq = _uniq + self.__lenOfSuperTags = len(superTags) + + def __repr__(self): + return '%s(%s)' % ( + self.__class__.__name__, + ', '.join([repr(x) for x in self.__superTags]) + ) + + def __add__(self, superTag): + return self.__class__( + self.__baseTag, *self.__superTags + (superTag,) + ) + def __radd__(self, superTag): + return self.__class__( + self.__baseTag, *(superTag,) + self.__superTags + ) + + def tagExplicitly(self, superTag): + tagClass, tagFormat, tagId = superTag + if tagClass == tagClassUniversal: + raise error.PyAsn1Error( + 'Can\'t tag with UNIVERSAL-class tag' + ) + if tagFormat != tagFormatConstructed: + superTag = Tag(tagClass, tagFormatConstructed, tagId) + return self + superTag + + def tagImplicitly(self, superTag): + tagClass, tagFormat, tagId = superTag + if self.__superTags: + superTag = Tag(tagClass, self.__superTags[-1][1], tagId) + return self[:-1] + superTag + + def getBaseTag(self): return self.__baseTag + def __getitem__(self, idx): + if isinstance(idx, slice): + return self.__class__( + self.__baseTag, *getitem(self.__superTags, idx) + ) + return self.__superTags[idx] + def __eq__(self, other): return self.uniq == other.uniq + def __ne__(self, other): return self.uniq != other.uniq + def __lt__(self, other): return self.uniq < other.uniq + def __le__(self, other): return self.uniq <= other.uniq + def __gt__(self, other): return self.uniq > other.uniq + def __ge__(self, other): return self.uniq >= other.uniq + def __hash__(self): return self.__hashedSuperTags + def __len__(self): return self.__lenOfSuperTags + def isSuperTagSetOf(self, tagSet): + if len(tagSet) < self.__lenOfSuperTags: + return + idx = self.__lenOfSuperTags - 1 + while idx >= 0: + if self.__superTags[idx] != tagSet[idx]: + return + idx = idx - 1 + return 1 + +def initTagSet(tag): return TagSet(tag, tag) diff --git a/src/lib/pyasn1/type/tagmap.py b/src/lib/pyasn1/type/tagmap.py new file mode 100644 index 00000000..7cec3a10 --- /dev/null +++ b/src/lib/pyasn1/type/tagmap.py @@ -0,0 +1,52 @@ +from pyasn1 import error + +class TagMap: + def __init__(self, posMap={}, negMap={}, defType=None): + self.__posMap = posMap.copy() + self.__negMap = negMap.copy() + self.__defType = defType + + def __contains__(self, tagSet): + return tagSet in self.__posMap or \ + self.__defType is not None and tagSet not in self.__negMap + + def __getitem__(self, tagSet): + if tagSet in self.__posMap: + return self.__posMap[tagSet] + elif tagSet in self.__negMap: + raise error.PyAsn1Error('Key in negative map') + elif self.__defType is not None: + return self.__defType + else: + raise KeyError() + + def __repr__(self): + s = '%r/%r' % (self.__posMap, self.__negMap) + if self.__defType is not None: + s = s + '/%r' % (self.__defType,) + return s + + def clone(self, parentType, tagMap, uniq=False): + if self.__defType is not None and tagMap.getDef() is not None: + raise error.PyAsn1Error('Duplicate default value at %s' % (self,)) + if tagMap.getDef() is not None: + defType = tagMap.getDef() + else: + defType = self.__defType + + posMap = self.__posMap.copy() + for k in tagMap.getPosMap(): + if uniq and k in posMap: + raise error.PyAsn1Error('Duplicate positive key %s' % (k,)) + posMap[k] = parentType + + negMap = self.__negMap.copy() + negMap.update(tagMap.getNegMap()) + + return self.__class__( + posMap, negMap, defType, + ) + + def getPosMap(self): return self.__posMap.copy() + def getNegMap(self): return self.__negMap.copy() + def getDef(self): return self.__defType diff --git a/src/lib/pyasn1/type/univ.py b/src/lib/pyasn1/type/univ.py new file mode 100644 index 00000000..9cd16f8a --- /dev/null +++ b/src/lib/pyasn1/type/univ.py @@ -0,0 +1,1042 @@ +# ASN.1 "universal" data types +import operator, sys +from pyasn1.type import base, tag, constraint, namedtype, namedval, tagmap +from pyasn1.codec.ber import eoo +from pyasn1.compat import octets +from pyasn1 import error + +# "Simple" ASN.1 types (yet incomplete) + +class Integer(base.AbstractSimpleAsn1Item): + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x02) + ) + namedValues = namedval.NamedValues() + def __init__(self, value=None, tagSet=None, subtypeSpec=None, + namedValues=None): + if namedValues is None: + self.__namedValues = self.namedValues + else: + self.__namedValues = namedValues + base.AbstractSimpleAsn1Item.__init__( + self, value, tagSet, subtypeSpec + ) + + def __and__(self, value): return self.clone(self._value & value) + def __rand__(self, value): return self.clone(value & self._value) + def __or__(self, value): return self.clone(self._value | value) + def __ror__(self, value): return self.clone(value | self._value) + def __xor__(self, value): return self.clone(self._value ^ value) + def __rxor__(self, value): return self.clone(value ^ self._value) + def __lshift__(self, value): return self.clone(self._value << value) + def __rshift__(self, value): return self.clone(self._value >> value) + + def __add__(self, value): return self.clone(self._value + value) + def __radd__(self, value): return self.clone(value + self._value) + def __sub__(self, value): return self.clone(self._value - value) + def __rsub__(self, value): return self.clone(value - self._value) + def __mul__(self, value): return self.clone(self._value * value) + def __rmul__(self, value): return self.clone(value * self._value) + def __mod__(self, value): return self.clone(self._value % value) + def __rmod__(self, value): return self.clone(value % self._value) + def __pow__(self, value, modulo=None): return self.clone(pow(self._value, value, modulo)) + def __rpow__(self, value): return self.clone(pow(value, self._value)) + + if sys.version_info[0] <= 2: + def __div__(self, value): return self.clone(self._value // value) + def __rdiv__(self, value): return self.clone(value // self._value) + else: + def __truediv__(self, value): return self.clone(self._value / value) + def __rtruediv__(self, value): return self.clone(value / self._value) + def __divmod__(self, value): return self.clone(self._value // value) + def __rdivmod__(self, value): return self.clone(value // self._value) + + __hash__ = base.AbstractSimpleAsn1Item.__hash__ + + def __int__(self): return int(self._value) + if sys.version_info[0] <= 2: + def __long__(self): return long(self._value) + def __float__(self): return float(self._value) + def __abs__(self): return abs(self._value) + def __index__(self): return int(self._value) + + def __lt__(self, value): return self._value < value + def __le__(self, value): return self._value <= value + def __eq__(self, value): return self._value == value + def __ne__(self, value): return self._value != value + def __gt__(self, value): return self._value > value + def __ge__(self, value): return self._value >= value + + def prettyIn(self, value): + if not isinstance(value, str): + try: + return int(value) + except: + raise error.PyAsn1Error( + 'Can\'t coerce %s into integer: %s' % (value, sys.exc_info()[1]) + ) + r = self.__namedValues.getValue(value) + if r is not None: + return r + try: + return int(value) + except: + raise error.PyAsn1Error( + 'Can\'t coerce %s into integer: %s' % (value, sys.exc_info()[1]) + ) + + def prettyOut(self, value): + r = self.__namedValues.getName(value) + return r is None and str(value) or repr(r) + + def getNamedValues(self): return self.__namedValues + + def clone(self, value=None, tagSet=None, subtypeSpec=None, + namedValues=None): + if value is None and tagSet is None and subtypeSpec is None \ + and namedValues is None: + return self + if value is None: + value = self._value + if tagSet is None: + tagSet = self._tagSet + if subtypeSpec is None: + subtypeSpec = self._subtypeSpec + if namedValues is None: + namedValues = self.__namedValues + return self.__class__(value, tagSet, subtypeSpec, namedValues) + + def subtype(self, value=None, implicitTag=None, explicitTag=None, + subtypeSpec=None, namedValues=None): + if value is None: + value = self._value + if implicitTag is not None: + tagSet = self._tagSet.tagImplicitly(implicitTag) + elif explicitTag is not None: + tagSet = self._tagSet.tagExplicitly(explicitTag) + else: + tagSet = self._tagSet + if subtypeSpec is None: + subtypeSpec = self._subtypeSpec + else: + subtypeSpec = subtypeSpec + self._subtypeSpec + if namedValues is None: + namedValues = self.__namedValues + else: + namedValues = namedValues + self.__namedValues + return self.__class__(value, tagSet, subtypeSpec, namedValues) + +class Boolean(Integer): + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x01), + ) + subtypeSpec = Integer.subtypeSpec+constraint.SingleValueConstraint(0,1) + namedValues = Integer.namedValues.clone(('False', 0), ('True', 1)) + +class BitString(base.AbstractSimpleAsn1Item): + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x03) + ) + namedValues = namedval.NamedValues() + def __init__(self, value=None, tagSet=None, subtypeSpec=None, + namedValues=None): + if namedValues is None: + self.__namedValues = self.namedValues + else: + self.__namedValues = namedValues + base.AbstractSimpleAsn1Item.__init__( + self, value, tagSet, subtypeSpec + ) + + def clone(self, value=None, tagSet=None, subtypeSpec=None, + namedValues=None): + if value is None and tagSet is None and subtypeSpec is None \ + and namedValues is None: + return self + if value is None: + value = self._value + if tagSet is None: + tagSet = self._tagSet + if subtypeSpec is None: + subtypeSpec = self._subtypeSpec + if namedValues is None: + namedValues = self.__namedValues + return self.__class__(value, tagSet, subtypeSpec, namedValues) + + def subtype(self, value=None, implicitTag=None, explicitTag=None, + subtypeSpec=None, namedValues=None): + if value is None: + value = self._value + if implicitTag is not None: + tagSet = self._tagSet.tagImplicitly(implicitTag) + elif explicitTag is not None: + tagSet = self._tagSet.tagExplicitly(explicitTag) + else: + tagSet = self._tagSet + if subtypeSpec is None: + subtypeSpec = self._subtypeSpec + else: + subtypeSpec = subtypeSpec + self._subtypeSpec + if namedValues is None: + namedValues = self.__namedValues + else: + namedValues = namedValues + self.__namedValues + return self.__class__(value, tagSet, subtypeSpec, namedValues) + + def __str__(self): return str(tuple(self)) + + # Immutable sequence object protocol + + def __len__(self): + if self._len is None: + self._len = len(self._value) + return self._len + def __getitem__(self, i): + if isinstance(i, slice): + return self.clone(operator.getitem(self._value, i)) + else: + return self._value[i] + + def __add__(self, value): return self.clone(self._value + value) + def __radd__(self, value): return self.clone(value + self._value) + def __mul__(self, value): return self.clone(self._value * value) + def __rmul__(self, value): return self * value + + def prettyIn(self, value): + r = [] + if not value: + return () + elif isinstance(value, str): + if value[0] == '\'': + if value[-2:] == '\'B': + for v in value[1:-2]: + if v == '0': + r.append(0) + elif v == '1': + r.append(1) + else: + raise error.PyAsn1Error( + 'Non-binary BIT STRING initializer %s' % (v,) + ) + return tuple(r) + elif value[-2:] == '\'H': + for v in value[1:-2]: + i = 4 + v = int(v, 16) + while i: + i = i - 1 + r.append((v>>i)&0x01) + return tuple(r) + else: + raise error.PyAsn1Error( + 'Bad BIT STRING value notation %s' % (value,) + ) + else: + for i in value.split(','): + j = self.__namedValues.getValue(i) + if j is None: + raise error.PyAsn1Error( + 'Unknown bit identifier \'%s\'' % (i,) + ) + if j >= len(r): + r.extend([0]*(j-len(r)+1)) + r[j] = 1 + return tuple(r) + elif isinstance(value, (tuple, list)): + r = tuple(value) + for b in r: + if b and b != 1: + raise error.PyAsn1Error( + 'Non-binary BitString initializer \'%s\'' % (r,) + ) + return r + elif isinstance(value, BitString): + return tuple(value) + else: + raise error.PyAsn1Error( + 'Bad BitString initializer type \'%s\'' % (value,) + ) + + def prettyOut(self, value): + return '\"\'%s\'B\"' % ''.join([str(x) for x in value]) + +class OctetString(base.AbstractSimpleAsn1Item): + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x04) + ) + defaultBinValue = defaultHexValue = base.noValue + encoding = 'us-ascii' + def __init__(self, value=None, tagSet=None, subtypeSpec=None, + encoding=None, binValue=None, hexValue=None): + if encoding is None: + self._encoding = self.encoding + else: + self._encoding = encoding + if binValue is not None: + value = self.fromBinaryString(binValue) + if hexValue is not None: + value = self.fromHexString(hexValue) + if value is None or value is base.noValue: + value = self.defaultHexValue + if value is None or value is base.noValue: + value = self.defaultBinValue + self.__intValue = None + base.AbstractSimpleAsn1Item.__init__(self, value, tagSet, subtypeSpec) + + def clone(self, value=None, tagSet=None, subtypeSpec=None, + encoding=None, binValue=None, hexValue=None): + if value is None and tagSet is None and subtypeSpec is None and \ + encoding is None and binValue is None and hexValue is None: + return self + if value is None and binValue is None and hexValue is None: + value = self._value + if tagSet is None: + tagSet = self._tagSet + if subtypeSpec is None: + subtypeSpec = self._subtypeSpec + if encoding is None: + encoding = self._encoding + return self.__class__( + value, tagSet, subtypeSpec, encoding, binValue, hexValue + ) + + if sys.version_info[0] <= 2: + def prettyIn(self, value): + if isinstance(value, str): + return value + elif isinstance(value, (tuple, list)): + try: + return ''.join([ chr(x) for x in value ]) + except ValueError: + raise error.PyAsn1Error( + 'Bad OctetString initializer \'%s\'' % (value,) + ) + else: + return str(value) + else: + def prettyIn(self, value): + if isinstance(value, bytes): + return value + elif isinstance(value, OctetString): + return value.asOctets() + elif isinstance(value, (tuple, list, map)): + try: + return bytes(value) + except ValueError: + raise error.PyAsn1Error( + 'Bad OctetString initializer \'%s\'' % (value,) + ) + else: + try: + return str(value).encode(self._encoding) + except UnicodeEncodeError: + raise error.PyAsn1Error( + 'Can\'t encode string \'%s\' with \'%s\' codec' % (value, self._encoding) + ) + + + def fromBinaryString(self, value): + bitNo = 8; byte = 0; r = () + for v in value: + if bitNo: + bitNo = bitNo - 1 + else: + bitNo = 7 + r = r + (byte,) + byte = 0 + if v == '0': + v = 0 + elif v == '1': + v = 1 + else: + raise error.PyAsn1Error( + 'Non-binary OCTET STRING initializer %s' % (v,) + ) + byte = byte | (v << bitNo) + return octets.ints2octs(r + (byte,)) + + def fromHexString(self, value): + r = p = () + for v in value: + if p: + r = r + (int(p+v, 16),) + p = () + else: + p = v + if p: + r = r + (int(p+'0', 16),) + return octets.ints2octs(r) + + def prettyOut(self, value): + if sys.version_info[0] <= 2: + numbers = tuple([ ord(x) for x in value ]) + else: + numbers = tuple(value) + if [ x for x in numbers if x < 32 or x > 126 ]: + return '0x' + ''.join([ '%.2x' % x for x in numbers ]) + else: + return str(value) + + def __repr__(self): + if self._value is base.noValue: + return self.__class__.__name__ + '()' + if [ x for x in self.asNumbers() if x < 32 or x > 126 ]: + return self.__class__.__name__ + '(hexValue=\'' + ''.join([ '%.2x' % x for x in self.asNumbers() ])+'\')' + else: + return self.__class__.__name__ + '(\'' + self.prettyOut(self._value) + '\')' + + if sys.version_info[0] <= 2: + def __str__(self): return str(self._value) + def __unicode__(self): + return self._value.decode(self._encoding, 'ignore') + def asOctets(self): return self._value + def asNumbers(self): + if self.__intValue is None: + self.__intValue = tuple([ ord(x) for x in self._value ]) + return self.__intValue + else: + def __str__(self): return self._value.decode(self._encoding, 'ignore') + def __bytes__(self): return self._value + def asOctets(self): return self._value + def asNumbers(self): + if self.__intValue is None: + self.__intValue = tuple(self._value) + return self.__intValue + + # Immutable sequence object protocol + + def __len__(self): + if self._len is None: + self._len = len(self._value) + return self._len + def __getitem__(self, i): + if isinstance(i, slice): + return self.clone(operator.getitem(self._value, i)) + else: + return self._value[i] + + def __add__(self, value): return self.clone(self._value + self.prettyIn(value)) + def __radd__(self, value): return self.clone(self.prettyIn(value) + self._value) + def __mul__(self, value): return self.clone(self._value * value) + def __rmul__(self, value): return self * value + +class Null(OctetString): + defaultValue = ''.encode() # This is tightly constrained + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x05) + ) + subtypeSpec = OctetString.subtypeSpec+constraint.SingleValueConstraint(''.encode()) + +if sys.version_info[0] <= 2: + intTypes = (int, long) +else: + intTypes = int + +class ObjectIdentifier(base.AbstractSimpleAsn1Item): + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x06) + ) + def __add__(self, other): return self.clone(self._value + other) + def __radd__(self, other): return self.clone(other + self._value) + + def asTuple(self): return self._value + + # Sequence object protocol + + def __len__(self): + if self._len is None: + self._len = len(self._value) + return self._len + def __getitem__(self, i): + if isinstance(i, slice): + return self.clone( + operator.getitem(self._value, i) + ) + else: + return self._value[i] + + def __str__(self): return self.prettyPrint() + + def index(self, suboid): return self._value.index(suboid) + + def isPrefixOf(self, value): + """Returns true if argument OID resides deeper in the OID tree""" + l = len(self) + if l <= len(value): + if self._value[:l] == value[:l]: + return 1 + return 0 + + def prettyIn(self, value): + """Dotted -> tuple of numerics OID converter""" + if isinstance(value, tuple): + pass + elif isinstance(value, ObjectIdentifier): + return tuple(value) + elif isinstance(value, str): + r = [] + for element in [ x for x in value.split('.') if x != '' ]: + try: + r.append(int(element, 0)) + except ValueError: + raise error.PyAsn1Error( + 'Malformed Object ID %s at %s: %s' % + (str(value), self.__class__.__name__, sys.exc_info()[1]) + ) + value = tuple(r) + else: + try: + value = tuple(value) + except TypeError: + raise error.PyAsn1Error( + 'Malformed Object ID %s at %s: %s' % + (str(value), self.__class__.__name__,sys.exc_info()[1]) + ) + + for x in value: + if not isinstance(x, intTypes) or x < 0: + raise error.PyAsn1Error( + 'Invalid sub-ID in %s at %s' % (value, self.__class__.__name__) + ) + + return value + + def prettyOut(self, value): return '.'.join([ str(x) for x in value ]) + +class Real(base.AbstractSimpleAsn1Item): + try: + _plusInf = float('inf') + _minusInf = float('-inf') + _inf = (_plusInf, _minusInf) + except ValueError: + # Infinity support is platform and Python dependent + _plusInf = _minusInf = None + _inf = () + + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x09) + ) + + def __normalizeBase10(self, value): + m, b, e = value + while m and m % 10 == 0: + m = m / 10 + e = e + 1 + return m, b, e + + def prettyIn(self, value): + if isinstance(value, tuple) and len(value) == 3: + for d in value: + if not isinstance(d, intTypes): + raise error.PyAsn1Error( + 'Lame Real value syntax: %s' % (value,) + ) + if value[1] not in (2, 10): + raise error.PyAsn1Error( + 'Prohibited base for Real value: %s' % (value[1],) + ) + if value[1] == 10: + value = self.__normalizeBase10(value) + return value + elif isinstance(value, intTypes): + return self.__normalizeBase10((value, 10, 0)) + elif isinstance(value, float): + if self._inf and value in self._inf: + return value + else: + e = 0 + while int(value) != value: + value = value * 10 + e = e - 1 + return self.__normalizeBase10((int(value), 10, e)) + elif isinstance(value, Real): + return tuple(value) + elif isinstance(value, str): # handle infinite literal + try: + return float(value) + except ValueError: + pass + raise error.PyAsn1Error( + 'Bad real value syntax: %s' % (value,) + ) + + def prettyOut(self, value): + if value in self._inf: + return '\'%s\'' % value + else: + return str(value) + + def isPlusInfinity(self): return self._value == self._plusInf + def isMinusInfinity(self): return self._value == self._minusInf + def isInfinity(self): return self._value in self._inf + + def __str__(self): return str(float(self)) + + def __add__(self, value): return self.clone(float(self) + value) + def __radd__(self, value): return self + value + def __mul__(self, value): return self.clone(float(self) * value) + def __rmul__(self, value): return self * value + def __sub__(self, value): return self.clone(float(self) - value) + def __rsub__(self, value): return self.clone(value - float(self)) + def __mod__(self, value): return self.clone(float(self) % value) + def __rmod__(self, value): return self.clone(value % float(self)) + def __pow__(self, value, modulo=None): return self.clone(pow(float(self), value, modulo)) + def __rpow__(self, value): return self.clone(pow(value, float(self))) + + if sys.version_info[0] <= 2: + def __div__(self, value): return self.clone(float(self) / value) + def __rdiv__(self, value): return self.clone(value / float(self)) + else: + def __truediv__(self, value): return self.clone(float(self) / value) + def __rtruediv__(self, value): return self.clone(value / float(self)) + def __divmod__(self, value): return self.clone(float(self) // value) + def __rdivmod__(self, value): return self.clone(value // float(self)) + + def __int__(self): return int(float(self)) + if sys.version_info[0] <= 2: + def __long__(self): return long(float(self)) + def __float__(self): + if self._value in self._inf: + return self._value + else: + return float( + self._value[0] * pow(self._value[1], self._value[2]) + ) + def __abs__(self): return abs(float(self)) + + def __lt__(self, value): return float(self) < value + def __le__(self, value): return float(self) <= value + def __eq__(self, value): return float(self) == value + def __ne__(self, value): return float(self) != value + def __gt__(self, value): return float(self) > value + def __ge__(self, value): return float(self) >= value + + if sys.version_info[0] <= 2: + def __nonzero__(self): return bool(float(self)) + else: + def __bool__(self): return bool(float(self)) + __hash__ = base.AbstractSimpleAsn1Item.__hash__ + + def __getitem__(self, idx): + if self._value in self._inf: + raise error.PyAsn1Error('Invalid infinite value operation') + else: + return self._value[idx] + +class Enumerated(Integer): + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 0x0A) + ) + +# "Structured" ASN.1 types + +class SetOf(base.AbstractConstructedAsn1Item): + componentType = None + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatConstructed, 0x11) + ) + typeId = 1 + + def _cloneComponentValues(self, myClone, cloneValueFlag): + idx = 0; l = len(self._componentValues) + while idx < l: + c = self._componentValues[idx] + if c is not None: + if isinstance(c, base.AbstractConstructedAsn1Item): + myClone.setComponentByPosition( + idx, c.clone(cloneValueFlag=cloneValueFlag) + ) + else: + myClone.setComponentByPosition(idx, c.clone()) + idx = idx + 1 + + def _verifyComponent(self, idx, value): + if self._componentType is not None and \ + not self._componentType.isSuperTypeOf(value): + raise error.PyAsn1Error('Component type error %s' % (value,)) + + def getComponentByPosition(self, idx): return self._componentValues[idx] + def setComponentByPosition(self, idx, value=None, verifyConstraints=True): + l = len(self._componentValues) + if idx >= l: + self._componentValues = self._componentValues + (idx-l+1)*[None] + if value is None: + if self._componentValues[idx] is None: + if self._componentType is None: + raise error.PyAsn1Error('Component type not defined') + self._componentValues[idx] = self._componentType.clone() + self._componentValuesSet = self._componentValuesSet + 1 + return self + elif not isinstance(value, base.Asn1Item): + if self._componentType is None: + raise error.PyAsn1Error('Component type not defined') + if isinstance(self._componentType, base.AbstractSimpleAsn1Item): + value = self._componentType.clone(value=value) + else: + raise error.PyAsn1Error('Instance value required') + if verifyConstraints: + if self._componentType is not None: + self._verifyComponent(idx, value) + self._verifySubtypeSpec(value, idx) + if self._componentValues[idx] is None: + self._componentValuesSet = self._componentValuesSet + 1 + self._componentValues[idx] = value + return self + + def getComponentTagMap(self): + if self._componentType is not None: + return self._componentType.getTagMap() + + def prettyPrint(self, scope=0): + scope = scope + 1 + r = self.__class__.__name__ + ':\n' + for idx in range(len(self._componentValues)): + r = r + ' '*scope + if self._componentValues[idx] is None: + r = r + '' + else: + r = r + self._componentValues[idx].prettyPrint(scope) + return r + +class SequenceOf(SetOf): + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatConstructed, 0x10) + ) + typeId = 2 + +class SequenceAndSetBase(base.AbstractConstructedAsn1Item): + componentType = namedtype.NamedTypes() + def __init__(self, componentType=None, tagSet=None, + subtypeSpec=None, sizeSpec=None): + base.AbstractConstructedAsn1Item.__init__( + self, componentType, tagSet, subtypeSpec, sizeSpec + ) + if self._componentType is None: + self._componentTypeLen = 0 + else: + self._componentTypeLen = len(self._componentType) + + def __getitem__(self, idx): + if isinstance(idx, str): + return self.getComponentByName(idx) + else: + return base.AbstractConstructedAsn1Item.__getitem__(self, idx) + + def __setitem__(self, idx, value): + if isinstance(idx, str): + self.setComponentByName(idx, value) + else: + base.AbstractConstructedAsn1Item.__setitem__(self, idx, value) + + def _cloneComponentValues(self, myClone, cloneValueFlag): + idx = 0; l = len(self._componentValues) + while idx < l: + c = self._componentValues[idx] + if c is not None: + if isinstance(c, base.AbstractConstructedAsn1Item): + myClone.setComponentByPosition( + idx, c.clone(cloneValueFlag=cloneValueFlag) + ) + else: + myClone.setComponentByPosition(idx, c.clone()) + idx = idx + 1 + + def _verifyComponent(self, idx, value): + if idx >= self._componentTypeLen: + raise error.PyAsn1Error( + 'Component type error out of range' + ) + t = self._componentType[idx].getType() + if not t.isSuperTypeOf(value): + raise error.PyAsn1Error('Component type error %r vs %r' % (t, value)) + + def getComponentByName(self, name): + return self.getComponentByPosition( + self._componentType.getPositionByName(name) + ) + def setComponentByName(self, name, value=None, verifyConstraints=True): + return self.setComponentByPosition( + self._componentType.getPositionByName(name), value, + verifyConstraints + ) + + def getComponentByPosition(self, idx): + try: + return self._componentValues[idx] + except IndexError: + if idx < self._componentTypeLen: + return + raise + def setComponentByPosition(self, idx, value=None, verifyConstraints=True): + l = len(self._componentValues) + if idx >= l: + self._componentValues = self._componentValues + (idx-l+1)*[None] + if value is None: + if self._componentValues[idx] is None: + self._componentValues[idx] = self._componentType.getTypeByPosition(idx).clone() + self._componentValuesSet = self._componentValuesSet + 1 + return self + elif not isinstance(value, base.Asn1Item): + t = self._componentType.getTypeByPosition(idx) + if isinstance(t, base.AbstractSimpleAsn1Item): + value = t.clone(value=value) + else: + raise error.PyAsn1Error('Instance value required') + if verifyConstraints: + if self._componentTypeLen: + self._verifyComponent(idx, value) + self._verifySubtypeSpec(value, idx) + if self._componentValues[idx] is None: + self._componentValuesSet = self._componentValuesSet + 1 + self._componentValues[idx] = value + return self + + def getNameByPosition(self, idx): + if self._componentTypeLen: + return self._componentType.getNameByPosition(idx) + + def getDefaultComponentByPosition(self, idx): + if self._componentTypeLen and self._componentType[idx].isDefaulted: + return self._componentType[idx].getType() + + def getComponentType(self): + if self._componentTypeLen: + return self._componentType + + def setDefaultComponents(self): + if self._componentTypeLen == self._componentValuesSet: + return + idx = self._componentTypeLen + while idx: + idx = idx - 1 + if self._componentType[idx].isDefaulted: + if self.getComponentByPosition(idx) is None: + self.setComponentByPosition(idx) + elif not self._componentType[idx].isOptional: + if self.getComponentByPosition(idx) is None: + raise error.PyAsn1Error( + 'Uninitialized component #%s at %r' % (idx, self) + ) + + def prettyPrint(self, scope=0): + scope = scope + 1 + r = self.__class__.__name__ + ':\n' + for idx in range(len(self._componentValues)): + if self._componentValues[idx] is not None: + r = r + ' '*scope + componentType = self.getComponentType() + if componentType is None: + r = r + '' + else: + r = r + componentType.getNameByPosition(idx) + r = '%s=%s\n' % ( + r, self._componentValues[idx].prettyPrint(scope) + ) + return r + +class Sequence(SequenceAndSetBase): + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatConstructed, 0x10) + ) + typeId = 3 + + def getComponentTagMapNearPosition(self, idx): + if self._componentType: + return self._componentType.getTagMapNearPosition(idx) + + def getComponentPositionNearType(self, tagSet, idx): + if self._componentType: + return self._componentType.getPositionNearType(tagSet, idx) + else: + return idx + +class Set(SequenceAndSetBase): + tagSet = baseTagSet = tag.initTagSet( + tag.Tag(tag.tagClassUniversal, tag.tagFormatConstructed, 0x11) + ) + typeId = 4 + + def getComponent(self, innerFlag=0): return self + + def getComponentByType(self, tagSet, innerFlag=0): + c = self.getComponentByPosition( + self._componentType.getPositionByType(tagSet) + ) + if innerFlag and isinstance(c, Set): + # get inner component by inner tagSet + return c.getComponent(1) + else: + # get outer component by inner tagSet + return c + + def setComponentByType(self, tagSet, value=None, innerFlag=0, + verifyConstraints=True): + idx = self._componentType.getPositionByType(tagSet) + t = self._componentType.getTypeByPosition(idx) + if innerFlag: # set inner component by inner tagSet + if t.getTagSet(): + return self.setComponentByPosition( + idx, value, verifyConstraints + ) + else: + t = self.setComponentByPosition(idx).getComponentByPosition(idx) + return t.setComponentByType( + tagSet, value, innerFlag, verifyConstraints + ) + else: # set outer component by inner tagSet + return self.setComponentByPosition( + idx, value, verifyConstraints + ) + + def getComponentTagMap(self): + if self._componentType: + return self._componentType.getTagMap(True) + + def getComponentPositionByType(self, tagSet): + if self._componentType: + return self._componentType.getPositionByType(tagSet) + +class Choice(Set): + tagSet = baseTagSet = tag.TagSet() # untagged + sizeSpec = constraint.ConstraintsIntersection( + constraint.ValueSizeConstraint(1, 1) + ) + typeId = 5 + _currentIdx = None + + def __eq__(self, other): + if self._componentValues: + return self._componentValues[self._currentIdx] == other + return NotImplemented + def __ne__(self, other): + if self._componentValues: + return self._componentValues[self._currentIdx] != other + return NotImplemented + def __lt__(self, other): + if self._componentValues: + return self._componentValues[self._currentIdx] < other + return NotImplemented + def __le__(self, other): + if self._componentValues: + return self._componentValues[self._currentIdx] <= other + return NotImplemented + def __gt__(self, other): + if self._componentValues: + return self._componentValues[self._currentIdx] > other + return NotImplemented + def __ge__(self, other): + if self._componentValues: + return self._componentValues[self._currentIdx] >= other + return NotImplemented + if sys.version_info[0] <= 2: + def __nonzero__(self): return bool(self._componentValues) + else: + def __bool__(self): return bool(self._componentValues) + + def __len__(self): return self._currentIdx is not None and 1 or 0 + + def verifySizeSpec(self): + if self._currentIdx is None: + raise error.PyAsn1Error('Component not chosen') + else: + self._sizeSpec(' ') + + def _cloneComponentValues(self, myClone, cloneValueFlag): + try: + c = self.getComponent() + except error.PyAsn1Error: + pass + else: + if isinstance(c, Choice): + tagSet = c.getEffectiveTagSet() + else: + tagSet = c.getTagSet() + if isinstance(c, base.AbstractConstructedAsn1Item): + myClone.setComponentByType( + tagSet, c.clone(cloneValueFlag=cloneValueFlag) + ) + else: + myClone.setComponentByType(tagSet, c.clone()) + + def setComponentByPosition(self, idx, value=None, verifyConstraints=True): + l = len(self._componentValues) + if idx >= l: + self._componentValues = self._componentValues + (idx-l+1)*[None] + if self._currentIdx is not None: + self._componentValues[self._currentIdx] = None + if value is None: + if self._componentValues[idx] is None: + self._componentValues[idx] = self._componentType.getTypeByPosition(idx).clone() + self._componentValuesSet = 1 + self._currentIdx = idx + return self + elif not isinstance(value, base.Asn1Item): + value = self._componentType.getTypeByPosition(idx).clone( + value=value + ) + if verifyConstraints: + if self._componentTypeLen: + self._verifyComponent(idx, value) + self._verifySubtypeSpec(value, idx) + self._componentValues[idx] = value + self._currentIdx = idx + self._componentValuesSet = 1 + return self + + def getMinTagSet(self): + if self._tagSet: + return self._tagSet + else: + return self._componentType.genMinTagSet() + + def getEffectiveTagSet(self): + if self._tagSet: + return self._tagSet + else: + c = self.getComponent() + if isinstance(c, Choice): + return c.getEffectiveTagSet() + else: + return c.getTagSet() + + def getTagMap(self): + if self._tagSet: + return Set.getTagMap(self) + else: + return Set.getComponentTagMap(self) + + def getComponent(self, innerFlag=0): + if self._currentIdx is None: + raise error.PyAsn1Error('Component not chosen') + else: + c = self._componentValues[self._currentIdx] + if innerFlag and isinstance(c, Choice): + return c.getComponent(innerFlag) + else: + return c + + def getName(self, innerFlag=0): + if self._currentIdx is None: + raise error.PyAsn1Error('Component not chosen') + else: + if innerFlag: + c = self._componentValues[self._currentIdx] + if isinstance(c, Choice): + return c.getName(innerFlag) + return self._componentType.getNameByPosition(self._currentIdx) + + def setDefaultComponents(self): pass + +class Any(OctetString): + tagSet = baseTagSet = tag.TagSet() # untagged + typeId = 6 + + def getTagMap(self): + return tagmap.TagMap( + { self.getTagSet(): self }, + { eoo.endOfOctets.getTagSet(): eoo.endOfOctets }, + self + ) + +# XXX +# coercion rules? diff --git a/src/lib/pyasn1/type/useful.py b/src/lib/pyasn1/type/useful.py new file mode 100644 index 00000000..a7139c22 --- /dev/null +++ b/src/lib/pyasn1/type/useful.py @@ -0,0 +1,12 @@ +# ASN.1 "useful" types +from pyasn1.type import char, tag + +class GeneralizedTime(char.VisibleString): + tagSet = char.VisibleString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 24) + ) + +class UTCTime(char.VisibleString): + tagSet = char.VisibleString.tagSet.tagImplicitly( + tag.Tag(tag.tagClassUniversal, tag.tagFormatSimple, 23) + ) diff --git a/src/lib/rsa/CHANGELOG.txt b/src/lib/rsa/CHANGELOG.txt new file mode 100644 index 00000000..2d8f5cf9 --- /dev/null +++ b/src/lib/rsa/CHANGELOG.txt @@ -0,0 +1,55 @@ +Python-RSA changelog +======================================== + +Version 3.1.1 - in development +---------------------------------------- + +- Fixed doctests for Python 2.7 +- Removed obsolete unittest so all tests run fine on Python 3.2 + +Version 3.1 - released 2012-06-17 +---------------------------------------- + +- Big, big credits to Yesudeep Mangalapilly for all the changes listed + below! +- Added ability to generate keys on multiple cores simultaneously. +- Massive speedup +- Partial Python 3.2 compatibility (core functionality works, but + saving or loading keys doesn't, for that the pyasn1 package needs to + be ported to Python 3 first) +- Lots of bug fixes + + + +Version 3.0.1 - released 2011-08-07 +---------------------------------------- + +- Removed unused import of abc module + + +Version 3.0 - released 2011-08-05 +---------------------------------------- + +- Changed the meaning of the keysize to mean the size of ``n`` rather than + the size of both ``p`` and ``q``. This is the common interpretation of + RSA keysize. To get the old behaviour, double the keysize when generating a + new key. + +- Added a lot of doctests + +- Added random-padded encryption and decryption using PKCS#1 version 1.5 + +- Added hash-based signatures and verification using PKCS#1v1.5 + +- Modeling private and public key as real objects rather than dicts. + +- Support for saving and loading keys as PEM and DER files. + +- Ability to extract a public key from a private key (PEM+DER) + + +Version 2.0 +---------------------------------------- + +- Security improvements by Barry Mead. + diff --git a/src/lib/rsa/LICENSE b/src/lib/rsa/LICENSE new file mode 100644 index 00000000..da76c9d7 --- /dev/null +++ b/src/lib/rsa/LICENSE @@ -0,0 +1,13 @@ +Copyright 2011 Sybren A. Stüvel + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. diff --git a/src/lib/rsa/README.rst b/src/lib/rsa/README.rst new file mode 100644 index 00000000..9f348636 --- /dev/null +++ b/src/lib/rsa/README.rst @@ -0,0 +1,31 @@ +Pure Python RSA implementation +============================== + +`Python-RSA`_ is a pure-Python RSA implementation. It supports +encryption and decryption, signing and verifying signatures, and key +generation according to PKCS#1 version 1.5. It can be used as a Python +library as well as on the commandline. The code was mostly written by +Sybren A. Stüvel. + +Documentation can be found at the Python-RSA homepage: +http://stuvel.eu/rsa + +Download and install using:: + + pip install rsa + +or:: + + easy_install rsa + +or download it from the `Python Package Index`_. + +The source code is maintained in a `Mercurial repository`_ and is +licensed under the `Apache License, version 2.0`_ + + +.. _`Python-RSA`: http://stuvel.eu/rsa +.. _`Mercurial repository`: https://bitbucket.org/sybren/python-rsa +.. _`Python Package Index`: http://pypi.python.org/pypi/rsa +.. _`Apache License, version 2.0`: http://www.apache.org/licenses/LICENSE-2.0 + diff --git a/src/lib/rsa/__init__.py b/src/lib/rsa/__init__.py new file mode 100644 index 00000000..99fd6689 --- /dev/null +++ b/src/lib/rsa/__init__.py @@ -0,0 +1,45 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RSA module + +Module for calculating large primes, and RSA encryption, decryption, signing +and verification. Includes generating public and private keys. + +WARNING: this implementation does not use random padding, compression of the +cleartext input to prevent repetitions, or other common security improvements. +Use with care. + +If you want to have a more secure implementation, use the functions from the +``rsa.pkcs1`` module. + +""" + +__author__ = "Sybren Stuvel, Barry Mead and Yesudeep Mangalapilly" +__date__ = "2015-11-05" +__version__ = '3.2.3' + +from rsa.key import newkeys, PrivateKey, PublicKey +from rsa.pkcs1 import encrypt, decrypt, sign, verify, DecryptionError, \ + VerificationError + +# Do doctest if we're run directly +if __name__ == "__main__": + import doctest + doctest.testmod() + +__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify", 'PublicKey', + 'PrivateKey', 'DecryptionError', 'VerificationError'] + diff --git a/src/lib/rsa/_compat.py b/src/lib/rsa/_compat.py new file mode 100644 index 00000000..3c4eb81b --- /dev/null +++ b/src/lib/rsa/_compat.py @@ -0,0 +1,160 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Python compatibility wrappers.""" + + +from __future__ import absolute_import + +import sys +from struct import pack + +try: + MAX_INT = sys.maxsize +except AttributeError: + MAX_INT = sys.maxint + +MAX_INT64 = (1 << 63) - 1 +MAX_INT32 = (1 << 31) - 1 +MAX_INT16 = (1 << 15) - 1 + +# Determine the word size of the processor. +if MAX_INT == MAX_INT64: + # 64-bit processor. + MACHINE_WORD_SIZE = 64 +elif MAX_INT == MAX_INT32: + # 32-bit processor. + MACHINE_WORD_SIZE = 32 +else: + # Else we just assume 64-bit processor keeping up with modern times. + MACHINE_WORD_SIZE = 64 + + +try: + # < Python3 + unicode_type = unicode + have_python3 = False +except NameError: + # Python3. + unicode_type = str + have_python3 = True + +# Fake byte literals. +if str is unicode_type: + def byte_literal(s): + return s.encode('latin1') +else: + def byte_literal(s): + return s + +# ``long`` is no more. Do type detection using this instead. +try: + integer_types = (int, long) +except NameError: + integer_types = (int,) + +b = byte_literal + +try: + # Python 2.6 or higher. + bytes_type = bytes +except NameError: + # Python 2.5 + bytes_type = str + + +# To avoid calling b() multiple times in tight loops. +ZERO_BYTE = b('\x00') +EMPTY_BYTE = b('') + + +def is_bytes(obj): + """ + Determines whether the given value is a byte string. + + :param obj: + The value to test. + :returns: + ``True`` if ``value`` is a byte string; ``False`` otherwise. + """ + return isinstance(obj, bytes_type) + + +def is_integer(obj): + """ + Determines whether the given value is an integer. + + :param obj: + The value to test. + :returns: + ``True`` if ``value`` is an integer; ``False`` otherwise. + """ + return isinstance(obj, integer_types) + + +def byte(num): + """ + Converts a number between 0 and 255 (both inclusive) to a base-256 (byte) + representation. + + Use it as a replacement for ``chr`` where you are expecting a byte + because this will work on all current versions of Python:: + + :param num: + An unsigned integer between 0 and 255 (both inclusive). + :returns: + A single byte. + """ + return pack("B", num) + + +def get_word_alignment(num, force_arch=64, + _machine_word_size=MACHINE_WORD_SIZE): + """ + Returns alignment details for the given number based on the platform + Python is running on. + + :param num: + Unsigned integral number. + :param force_arch: + If you don't want to use 64-bit unsigned chunks, set this to + anything other than 64. 32-bit chunks will be preferred then. + Default 64 will be used when on a 64-bit machine. + :param _machine_word_size: + (Internal) The machine word size used for alignment. + :returns: + 4-tuple:: + + (word_bits, word_bytes, + max_uint, packing_format_type) + """ + max_uint64 = 0xffffffffffffffff + max_uint32 = 0xffffffff + max_uint16 = 0xffff + max_uint8 = 0xff + + if force_arch == 64 and _machine_word_size >= 64 and num > max_uint32: + # 64-bit unsigned integer. + return 64, 8, max_uint64, "Q" + elif num > max_uint16: + # 32-bit unsigned integer + return 32, 4, max_uint32, "L" + elif num > max_uint8: + # 16-bit unsigned integer. + return 16, 2, max_uint16, "H" + else: + # 8-bit unsigned integer. + return 8, 1, max_uint8, "B" diff --git a/src/lib/rsa/_version133.py b/src/lib/rsa/_version133.py new file mode 100644 index 00000000..dff0dda8 --- /dev/null +++ b/src/lib/rsa/_version133.py @@ -0,0 +1,458 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RSA module +pri = k[1] //Private part of keys d,p,q + +Module for calculating large primes, and RSA encryption, decryption, +signing and verification. Includes generating public and private keys. + +WARNING: this code implements the mathematics of RSA. It is not suitable for +real-world secure cryptography purposes. It has not been reviewed by a security +expert. It does not include padding of data. There are many ways in which the +output of this module, when used without any modification, can be sucessfully +attacked. +""" + +__author__ = "Sybren Stuvel, Marloes de Boer and Ivo Tamboer" +__date__ = "2010-02-05" +__version__ = '1.3.3' + +# NOTE: Python's modulo can return negative numbers. We compensate for +# this behaviour using the abs() function + +from cPickle import dumps, loads +import base64 +import math +import os +import random +import sys +import types +import zlib + +from rsa._compat import byte + +# Display a warning that this insecure version is imported. +import warnings +warnings.warn('Insecure version of the RSA module is imported as %s, be careful' + % __name__) + +def gcd(p, q): + """Returns the greatest common divisor of p and q + + + >>> gcd(42, 6) + 6 + """ + if p>> (128*256 + 64)*256 + + 15 + 8405007 + >>> l = [128, 64, 15] + >>> bytes2int(l) + 8405007 + """ + + if not (type(bytes) is types.ListType or type(bytes) is types.StringType): + raise TypeError("You must pass a string or a list") + + # Convert byte stream to integer + integer = 0 + for byte in bytes: + integer *= 256 + if type(byte) is types.StringType: byte = ord(byte) + integer += byte + + return integer + +def int2bytes(number): + """Converts a number to a string of bytes + + >>> bytes2int(int2bytes(123456789)) + 123456789 + """ + + if not (type(number) is types.LongType or type(number) is types.IntType): + raise TypeError("You must pass a long or an int") + + string = "" + + while number > 0: + string = "%s%s" % (byte(number & 0xFF), string) + number /= 256 + + return string + +def fast_exponentiation(a, p, n): + """Calculates r = a^p mod n + """ + result = a % n + remainders = [] + while p != 1: + remainders.append(p & 1) + p = p >> 1 + while remainders: + rem = remainders.pop() + result = ((a ** rem) * result ** 2) % n + return result + +def read_random_int(nbits): + """Reads a random integer of approximately nbits bits rounded up + to whole bytes""" + + nbytes = ceil(nbits/8.) + randomdata = os.urandom(nbytes) + return bytes2int(randomdata) + +def ceil(x): + """ceil(x) -> int(math.ceil(x))""" + + return int(math.ceil(x)) + +def randint(minvalue, maxvalue): + """Returns a random integer x with minvalue <= x <= maxvalue""" + + # Safety - get a lot of random data even if the range is fairly + # small + min_nbits = 32 + + # The range of the random numbers we need to generate + range = maxvalue - minvalue + + # Which is this number of bytes + rangebytes = ceil(math.log(range, 2) / 8.) + + # Convert to bits, but make sure it's always at least min_nbits*2 + rangebits = max(rangebytes * 8, min_nbits * 2) + + # Take a random number of bits between min_nbits and rangebits + nbits = random.randint(min_nbits, rangebits) + + return (read_random_int(nbits) % range) + minvalue + +def fermat_little_theorem(p): + """Returns 1 if p may be prime, and something else if p definitely + is not prime""" + + a = randint(1, p-1) + return fast_exponentiation(a, p-1, p) + +def jacobi(a, b): + """Calculates the value of the Jacobi symbol (a/b) + """ + + if a % b == 0: + return 0 + result = 1 + while a > 1: + if a & 1: + if ((a-1)*(b-1) >> 2) & 1: + result = -result + b, a = a, b % a + else: + if ((b ** 2 - 1) >> 3) & 1: + result = -result + a = a >> 1 + return result + +def jacobi_witness(x, n): + """Returns False if n is an Euler pseudo-prime with base x, and + True otherwise. + """ + + j = jacobi(x, n) % n + f = fast_exponentiation(x, (n-1)/2, n) + + if j == f: return False + return True + +def randomized_primality_testing(n, k): + """Calculates whether n is composite (which is always correct) or + prime (which is incorrect with error probability 2**-k) + + Returns False if the number if composite, and True if it's + probably prime. + """ + + q = 0.5 # Property of the jacobi_witness function + + # t = int(math.ceil(k / math.log(1/q, 2))) + t = ceil(k / math.log(1/q, 2)) + for i in range(t+1): + x = randint(1, n-1) + if jacobi_witness(x, n): return False + + return True + +def is_prime(number): + """Returns True if the number is prime, and False otherwise. + + >>> is_prime(42) + 0 + >>> is_prime(41) + 1 + """ + + """ + if not fermat_little_theorem(number) == 1: + # Not prime, according to Fermat's little theorem + return False + """ + + if randomized_primality_testing(number, 5): + # Prime, according to Jacobi + return True + + # Not prime + return False + + +def getprime(nbits): + """Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In + other words: nbits is rounded up to whole bytes. + + >>> p = getprime(8) + >>> is_prime(p-1) + 0 + >>> is_prime(p) + 1 + >>> is_prime(p+1) + 0 + """ + + nbytes = int(math.ceil(nbits/8.)) + + while True: + integer = read_random_int(nbits) + + # Make sure it's odd + integer |= 1 + + # Test for primeness + if is_prime(integer): break + + # Retry if not prime + + return integer + +def are_relatively_prime(a, b): + """Returns True if a and b are relatively prime, and False if they + are not. + + >>> are_relatively_prime(2, 3) + 1 + >>> are_relatively_prime(2, 4) + 0 + """ + + d = gcd(a, b) + return (d == 1) + +def find_p_q(nbits): + """Returns a tuple of two different primes of nbits bits""" + + p = getprime(nbits) + while True: + q = getprime(nbits) + if not q == p: break + + return (p, q) + +def extended_euclid_gcd(a, b): + """Returns a tuple (d, i, j) such that d = gcd(a, b) = ia + jb + """ + + if b == 0: + return (a, 1, 0) + + q = abs(a % b) + r = long(a / b) + (d, k, l) = extended_euclid_gcd(b, q) + + return (d, l, k - l*r) + +# Main function: calculate encryption and decryption keys +def calculate_keys(p, q, nbits): + """Calculates an encryption and a decryption key for p and q, and + returns them as a tuple (e, d)""" + + n = p * q + phi_n = (p-1) * (q-1) + + while True: + # Make sure e has enough bits so we ensure "wrapping" through + # modulo n + e = getprime(max(8, nbits/2)) + if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n): break + + (d, i, j) = extended_euclid_gcd(e, phi_n) + + if not d == 1: + raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n)) + + if not (e * i) % phi_n == 1: + raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n)) + + return (e, i) + + +def gen_keys(nbits): + """Generate RSA keys of nbits bits. Returns (p, q, e, d). + + Note: this can take a long time, depending on the key size. + """ + + while True: + (p, q) = find_p_q(nbits) + (e, d) = calculate_keys(p, q, nbits) + + # For some reason, d is sometimes negative. We don't know how + # to fix it (yet), so we keep trying until everything is shiny + if d > 0: break + + return (p, q, e, d) + +def gen_pubpriv_keys(nbits): + """Generates public and private keys, and returns them as (pub, + priv). + + The public key consists of a dict {e: ..., , n: ....). The private + key consists of a dict {d: ...., p: ...., q: ....). + """ + + (p, q, e, d) = gen_keys(nbits) + + return ( {'e': e, 'n': p*q}, {'d': d, 'p': p, 'q': q} ) + +def encrypt_int(message, ekey, n): + """Encrypts a message using encryption key 'ekey', working modulo + n""" + + if type(message) is types.IntType: + return encrypt_int(long(message), ekey, n) + + if not type(message) is types.LongType: + raise TypeError("You must pass a long or an int") + + if message > 0 and \ + math.floor(math.log(message, 2)) > math.floor(math.log(n, 2)): + raise OverflowError("The message is too long") + + return fast_exponentiation(message, ekey, n) + +def decrypt_int(cyphertext, dkey, n): + """Decrypts a cypher text using the decryption key 'dkey', working + modulo n""" + + return encrypt_int(cyphertext, dkey, n) + +def sign_int(message, dkey, n): + """Signs 'message' using key 'dkey', working modulo n""" + + return decrypt_int(message, dkey, n) + +def verify_int(signed, ekey, n): + """verifies 'signed' using key 'ekey', working modulo n""" + + return encrypt_int(signed, ekey, n) + +def picklechops(chops): + """Pickles and base64encodes it's argument chops""" + + value = zlib.compress(dumps(chops)) + encoded = base64.encodestring(value) + return encoded.strip() + +def unpicklechops(string): + """base64decodes and unpickes it's argument string into chops""" + + return loads(zlib.decompress(base64.decodestring(string))) + +def chopstring(message, key, n, funcref): + """Splits 'message' into chops that are at most as long as n, + converts these into integers, and calls funcref(integer, key, n) + for each chop. + + Used by 'encrypt' and 'sign'. + """ + + msglen = len(message) + mbits = msglen * 8 + nbits = int(math.floor(math.log(n, 2))) + nbytes = nbits / 8 + blocks = msglen / nbytes + + if msglen % nbytes > 0: + blocks += 1 + + cypher = [] + + for bindex in range(blocks): + offset = bindex * nbytes + block = message[offset:offset+nbytes] + value = bytes2int(block) + cypher.append(funcref(value, key, n)) + + return picklechops(cypher) + +def gluechops(chops, key, n, funcref): + """Glues chops back together into a string. calls + funcref(integer, key, n) for each chop. + + Used by 'decrypt' and 'verify'. + """ + message = "" + + chops = unpicklechops(chops) + + for cpart in chops: + mpart = funcref(cpart, key, n) + message += int2bytes(mpart) + + return message + +def encrypt(message, key): + """Encrypts a string 'message' with the public key 'key'""" + + return chopstring(message, key['e'], key['n'], encrypt_int) + +def sign(message, key): + """Signs a string 'message' with the private key 'key'""" + + return chopstring(message, key['d'], key['p']*key['q'], decrypt_int) + +def decrypt(cypher, key): + """Decrypts a cypher with the private key 'key'""" + + return gluechops(cypher, key['d'], key['p']*key['q'], decrypt_int) + +def verify(cypher, key): + """Verifies a cypher with the public key 'key'""" + + return gluechops(cypher, key['e'], key['n'], encrypt_int) + +# Do doctest if we're not imported +if __name__ == "__main__": + import doctest + doctest.testmod() + +__all__ = ["gen_pubpriv_keys", "encrypt", "decrypt", "sign", "verify"] + diff --git a/src/lib/rsa/_version200.py b/src/lib/rsa/_version200.py new file mode 100644 index 00000000..28f36018 --- /dev/null +++ b/src/lib/rsa/_version200.py @@ -0,0 +1,545 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""RSA module + +Module for calculating large primes, and RSA encryption, decryption, +signing and verification. Includes generating public and private keys. + +WARNING: this implementation does not use random padding, compression of the +cleartext input to prevent repetitions, or other common security improvements. +Use with care. + +""" + +__author__ = "Sybren Stuvel, Marloes de Boer, Ivo Tamboer, and Barry Mead" +__date__ = "2010-02-08" +__version__ = '2.0' + +import math +import os +import random +import sys +import types +from rsa._compat import byte + +# Display a warning that this insecure version is imported. +import warnings +warnings.warn('Insecure version of the RSA module is imported as %s' % __name__) + + +def bit_size(number): + """Returns the number of bits required to hold a specific long number""" + + return int(math.ceil(math.log(number,2))) + +def gcd(p, q): + """Returns the greatest common divisor of p and q + >>> gcd(48, 180) + 12 + """ + # Iterateive Version is faster and uses much less stack space + while q != 0: + if p < q: (p,q) = (q,p) + (p,q) = (q, p % q) + return p + + +def bytes2int(bytes): + """Converts a list of bytes or a string to an integer + + >>> (((128 * 256) + 64) * 256) + 15 + 8405007 + >>> l = [128, 64, 15] + >>> bytes2int(l) #same as bytes2int('\x80@\x0f') + 8405007 + """ + + if not (type(bytes) is types.ListType or type(bytes) is types.StringType): + raise TypeError("You must pass a string or a list") + + # Convert byte stream to integer + integer = 0 + for byte in bytes: + integer *= 256 + if type(byte) is types.StringType: byte = ord(byte) + integer += byte + + return integer + +def int2bytes(number): + """ + Converts a number to a string of bytes + """ + + if not (type(number) is types.LongType or type(number) is types.IntType): + raise TypeError("You must pass a long or an int") + + string = "" + + while number > 0: + string = "%s%s" % (byte(number & 0xFF), string) + number /= 256 + + return string + +def to64(number): + """Converts a number in the range of 0 to 63 into base 64 digit + character in the range of '0'-'9', 'A'-'Z', 'a'-'z','-','_'. + + >>> to64(10) + 'A' + """ + + if not (type(number) is types.LongType or type(number) is types.IntType): + raise TypeError("You must pass a long or an int") + + if 0 <= number <= 9: #00-09 translates to '0' - '9' + return byte(number + 48) + + if 10 <= number <= 35: + return byte(number + 55) #10-35 translates to 'A' - 'Z' + + if 36 <= number <= 61: + return byte(number + 61) #36-61 translates to 'a' - 'z' + + if number == 62: # 62 translates to '-' (minus) + return byte(45) + + if number == 63: # 63 translates to '_' (underscore) + return byte(95) + + raise ValueError('Invalid Base64 value: %i' % number) + + +def from64(number): + """Converts an ordinal character value in the range of + 0-9,A-Z,a-z,-,_ to a number in the range of 0-63. + + >>> from64(49) + 1 + """ + + if not (type(number) is types.LongType or type(number) is types.IntType): + raise TypeError("You must pass a long or an int") + + if 48 <= number <= 57: #ord('0') - ord('9') translates to 0-9 + return(number - 48) + + if 65 <= number <= 90: #ord('A') - ord('Z') translates to 10-35 + return(number - 55) + + if 97 <= number <= 122: #ord('a') - ord('z') translates to 36-61 + return(number - 61) + + if number == 45: #ord('-') translates to 62 + return(62) + + if number == 95: #ord('_') translates to 63 + return(63) + + raise ValueError('Invalid Base64 value: %i' % number) + + +def int2str64(number): + """Converts a number to a string of base64 encoded characters in + the range of '0'-'9','A'-'Z,'a'-'z','-','_'. + + >>> int2str64(123456789) + '7MyqL' + """ + + if not (type(number) is types.LongType or type(number) is types.IntType): + raise TypeError("You must pass a long or an int") + + string = "" + + while number > 0: + string = "%s%s" % (to64(number & 0x3F), string) + number /= 64 + + return string + + +def str642int(string): + """Converts a base64 encoded string into an integer. + The chars of this string in in the range '0'-'9','A'-'Z','a'-'z','-','_' + + >>> str642int('7MyqL') + 123456789 + """ + + if not (type(string) is types.ListType or type(string) is types.StringType): + raise TypeError("You must pass a string or a list") + + integer = 0 + for byte in string: + integer *= 64 + if type(byte) is types.StringType: byte = ord(byte) + integer += from64(byte) + + return integer + +def read_random_int(nbits): + """Reads a random integer of approximately nbits bits rounded up + to whole bytes""" + + nbytes = int(math.ceil(nbits/8.)) + randomdata = os.urandom(nbytes) + return bytes2int(randomdata) + +def randint(minvalue, maxvalue): + """Returns a random integer x with minvalue <= x <= maxvalue""" + + # Safety - get a lot of random data even if the range is fairly + # small + min_nbits = 32 + + # The range of the random numbers we need to generate + range = (maxvalue - minvalue) + 1 + + # Which is this number of bytes + rangebytes = ((bit_size(range) + 7) / 8) + + # Convert to bits, but make sure it's always at least min_nbits*2 + rangebits = max(rangebytes * 8, min_nbits * 2) + + # Take a random number of bits between min_nbits and rangebits + nbits = random.randint(min_nbits, rangebits) + + return (read_random_int(nbits) % range) + minvalue + +def jacobi(a, b): + """Calculates the value of the Jacobi symbol (a/b) + where both a and b are positive integers, and b is odd + """ + + if a == 0: return 0 + result = 1 + while a > 1: + if a & 1: + if ((a-1)*(b-1) >> 2) & 1: + result = -result + a, b = b % a, a + else: + if (((b * b) - 1) >> 3) & 1: + result = -result + a >>= 1 + if a == 0: return 0 + return result + +def jacobi_witness(x, n): + """Returns False if n is an Euler pseudo-prime with base x, and + True otherwise. + """ + + j = jacobi(x, n) % n + f = pow(x, (n-1)/2, n) + + if j == f: return False + return True + +def randomized_primality_testing(n, k): + """Calculates whether n is composite (which is always correct) or + prime (which is incorrect with error probability 2**-k) + + Returns False if the number is composite, and True if it's + probably prime. + """ + + # 50% of Jacobi-witnesses can report compositness of non-prime numbers + + for i in range(k): + x = randint(1, n-1) + if jacobi_witness(x, n): return False + + return True + +def is_prime(number): + """Returns True if the number is prime, and False otherwise. + + >>> is_prime(42) + 0 + >>> is_prime(41) + 1 + """ + + if randomized_primality_testing(number, 6): + # Prime, according to Jacobi + return True + + # Not prime + return False + + +def getprime(nbits): + """Returns a prime number of max. 'math.ceil(nbits/8)*8' bits. In + other words: nbits is rounded up to whole bytes. + + >>> p = getprime(8) + >>> is_prime(p-1) + 0 + >>> is_prime(p) + 1 + >>> is_prime(p+1) + 0 + """ + + while True: + integer = read_random_int(nbits) + + # Make sure it's odd + integer |= 1 + + # Test for primeness + if is_prime(integer): break + + # Retry if not prime + + return integer + +def are_relatively_prime(a, b): + """Returns True if a and b are relatively prime, and False if they + are not. + + >>> are_relatively_prime(2, 3) + 1 + >>> are_relatively_prime(2, 4) + 0 + """ + + d = gcd(a, b) + return (d == 1) + +def find_p_q(nbits): + """Returns a tuple of two different primes of nbits bits""" + pbits = nbits + (nbits/16) #Make sure that p and q aren't too close + qbits = nbits - (nbits/16) #or the factoring programs can factor n + p = getprime(pbits) + while True: + q = getprime(qbits) + #Make sure p and q are different. + if not q == p: break + return (p, q) + +def extended_gcd(a, b): + """Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb + """ + # r = gcd(a,b) i = multiplicitive inverse of a mod b + # or j = multiplicitive inverse of b mod a + # Neg return values for i or j are made positive mod b or a respectively + # Iterateive Version is faster and uses much less stack space + x = 0 + y = 1 + lx = 1 + ly = 0 + oa = a #Remember original a/b to remove + ob = b #negative values from return results + while b != 0: + q = long(a/b) + (a, b) = (b, a % b) + (x, lx) = ((lx - (q * x)),x) + (y, ly) = ((ly - (q * y)),y) + if (lx < 0): lx += ob #If neg wrap modulo orignal b + if (ly < 0): ly += oa #If neg wrap modulo orignal a + return (a, lx, ly) #Return only positive values + +# Main function: calculate encryption and decryption keys +def calculate_keys(p, q, nbits): + """Calculates an encryption and a decryption key for p and q, and + returns them as a tuple (e, d)""" + + n = p * q + phi_n = (p-1) * (q-1) + + while True: + # Make sure e has enough bits so we ensure "wrapping" through + # modulo n + e = max(65537,getprime(nbits/4)) + if are_relatively_prime(e, n) and are_relatively_prime(e, phi_n): break + + (d, i, j) = extended_gcd(e, phi_n) + + if not d == 1: + raise Exception("e (%d) and phi_n (%d) are not relatively prime" % (e, phi_n)) + if (i < 0): + raise Exception("New extended_gcd shouldn't return negative values") + if not (e * i) % phi_n == 1: + raise Exception("e (%d) and i (%d) are not mult. inv. modulo phi_n (%d)" % (e, i, phi_n)) + + return (e, i) + + +def gen_keys(nbits): + """Generate RSA keys of nbits bits. Returns (p, q, e, d). + + Note: this can take a long time, depending on the key size. + """ + + (p, q) = find_p_q(nbits) + (e, d) = calculate_keys(p, q, nbits) + + return (p, q, e, d) + +def newkeys(nbits): + """Generates public and private keys, and returns them as (pub, + priv). + + The public key consists of a dict {e: ..., , n: ....). The private + key consists of a dict {d: ...., p: ...., q: ....). + """ + nbits = max(9,nbits) # Don't let nbits go below 9 bits + (p, q, e, d) = gen_keys(nbits) + + return ( {'e': e, 'n': p*q}, {'d': d, 'p': p, 'q': q} ) + +def encrypt_int(message, ekey, n): + """Encrypts a message using encryption key 'ekey', working modulo n""" + + if type(message) is types.IntType: + message = long(message) + + if not type(message) is types.LongType: + raise TypeError("You must pass a long or int") + + if message < 0 or message > n: + raise OverflowError("The message is too long") + + #Note: Bit exponents start at zero (bit counts start at 1) this is correct + safebit = bit_size(n) - 2 #compute safe bit (MSB - 1) + message += (1 << safebit) #add safebit to ensure folding + + return pow(message, ekey, n) + +def decrypt_int(cyphertext, dkey, n): + """Decrypts a cypher text using the decryption key 'dkey', working + modulo n""" + + message = pow(cyphertext, dkey, n) + + safebit = bit_size(n) - 2 #compute safe bit (MSB - 1) + message -= (1 << safebit) #remove safebit before decode + + return message + +def encode64chops(chops): + """base64encodes chops and combines them into a ',' delimited string""" + + chips = [] #chips are character chops + + for value in chops: + chips.append(int2str64(value)) + + #delimit chops with comma + encoded = ','.join(chips) + + return encoded + +def decode64chops(string): + """base64decodes and makes a ',' delimited string into chops""" + + chips = string.split(',') #split chops at commas + + chops = [] + + for string in chips: #make char chops (chips) into chops + chops.append(str642int(string)) + + return chops + +def chopstring(message, key, n, funcref): + """Chops the 'message' into integers that fit into n, + leaving room for a safebit to be added to ensure that all + messages fold during exponentiation. The MSB of the number n + is not independant modulo n (setting it could cause overflow), so + use the next lower bit for the safebit. Therefore reserve 2-bits + in the number n for non-data bits. Calls specified encryption + function for each chop. + + Used by 'encrypt' and 'sign'. + """ + + msglen = len(message) + mbits = msglen * 8 + #Set aside 2-bits so setting of safebit won't overflow modulo n. + nbits = bit_size(n) - 2 # leave room for safebit + nbytes = nbits / 8 + blocks = msglen / nbytes + + if msglen % nbytes > 0: + blocks += 1 + + cypher = [] + + for bindex in range(blocks): + offset = bindex * nbytes + block = message[offset:offset+nbytes] + value = bytes2int(block) + cypher.append(funcref(value, key, n)) + + return encode64chops(cypher) #Encode encrypted ints to base64 strings + +def gluechops(string, key, n, funcref): + """Glues chops back together into a string. calls + funcref(integer, key, n) for each chop. + + Used by 'decrypt' and 'verify'. + """ + message = "" + + chops = decode64chops(string) #Decode base64 strings into integer chops + + for cpart in chops: + mpart = funcref(cpart, key, n) #Decrypt each chop + message += int2bytes(mpart) #Combine decrypted strings into a msg + + return message + +def encrypt(message, key): + """Encrypts a string 'message' with the public key 'key'""" + if 'n' not in key: + raise Exception("You must use the public key with encrypt") + + return chopstring(message, key['e'], key['n'], encrypt_int) + +def sign(message, key): + """Signs a string 'message' with the private key 'key'""" + if 'p' not in key: + raise Exception("You must use the private key with sign") + + return chopstring(message, key['d'], key['p']*key['q'], encrypt_int) + +def decrypt(cypher, key): + """Decrypts a string 'cypher' with the private key 'key'""" + if 'p' not in key: + raise Exception("You must use the private key with decrypt") + + return gluechops(cypher, key['d'], key['p']*key['q'], decrypt_int) + +def verify(cypher, key): + """Verifies a string 'cypher' with the public key 'key'""" + if 'n' not in key: + raise Exception("You must use the public key with verify") + + return gluechops(cypher, key['e'], key['n'], decrypt_int) + +# Do doctest if we're not imported +if __name__ == "__main__": + import doctest + doctest.testmod() + +__all__ = ["newkeys", "encrypt", "decrypt", "sign", "verify"] + diff --git a/src/lib/rsa/asn1.py b/src/lib/rsa/asn1.py new file mode 100644 index 00000000..6eb6da53 --- /dev/null +++ b/src/lib/rsa/asn1.py @@ -0,0 +1,51 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''ASN.1 definitions. + +Not all ASN.1-handling code use these definitions, but when it does, they should be here. +''' + +from pyasn1.type import univ, namedtype, tag + +class PubKeyHeader(univ.Sequence): + componentType = namedtype.NamedTypes( + namedtype.NamedType('oid', univ.ObjectIdentifier()), + namedtype.NamedType('parameters', univ.Null()), + ) + +class OpenSSLPubKey(univ.Sequence): + componentType = namedtype.NamedTypes( + namedtype.NamedType('header', PubKeyHeader()), + + # This little hack (the implicit tag) allows us to get a Bit String as Octet String + namedtype.NamedType('key', univ.OctetString().subtype( + implicitTag=tag.Tag(tagClass=0, tagFormat=0, tagId=3))), + ) + + +class AsnPubKey(univ.Sequence): + '''ASN.1 contents of DER encoded public key: + + RSAPublicKey ::= SEQUENCE { + modulus INTEGER, -- n + publicExponent INTEGER, -- e + ''' + + componentType = namedtype.NamedTypes( + namedtype.NamedType('modulus', univ.Integer()), + namedtype.NamedType('publicExponent', univ.Integer()), + ) diff --git a/src/lib/rsa/bigfile.py b/src/lib/rsa/bigfile.py new file mode 100644 index 00000000..516cf56b --- /dev/null +++ b/src/lib/rsa/bigfile.py @@ -0,0 +1,87 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Large file support + + - break a file into smaller blocks, and encrypt them, and store the + encrypted blocks in another file. + + - take such an encrypted files, decrypt its blocks, and reconstruct the + original file. + +The encrypted file format is as follows, where || denotes byte concatenation: + + FILE := VERSION || BLOCK || BLOCK ... + + BLOCK := LENGTH || DATA + + LENGTH := varint-encoded length of the subsequent data. Varint comes from + Google Protobuf, and encodes an integer into a variable number of bytes. + Each byte uses the 7 lowest bits to encode the value. The highest bit set + to 1 indicates the next byte is also part of the varint. The last byte will + have this bit set to 0. + +This file format is called the VARBLOCK format, in line with the varint format +used to denote the block sizes. + +''' + +from rsa import key, common, pkcs1, varblock +from rsa._compat import byte + +def encrypt_bigfile(infile, outfile, pub_key): + '''Encrypts a file, writing it to 'outfile' in VARBLOCK format. + + :param infile: file-like object to read the cleartext from + :param outfile: file-like object to write the crypto in VARBLOCK format to + :param pub_key: :py:class:`rsa.PublicKey` to encrypt with + + ''' + + if not isinstance(pub_key, key.PublicKey): + raise TypeError('Public key required, but got %r' % pub_key) + + key_bytes = common.bit_size(pub_key.n) // 8 + blocksize = key_bytes - 11 # keep space for PKCS#1 padding + + # Write the version number to the VARBLOCK file + outfile.write(byte(varblock.VARBLOCK_VERSION)) + + # Encrypt and write each block + for block in varblock.yield_fixedblocks(infile, blocksize): + crypto = pkcs1.encrypt(block, pub_key) + + varblock.write_varint(outfile, len(crypto)) + outfile.write(crypto) + +def decrypt_bigfile(infile, outfile, priv_key): + '''Decrypts an encrypted VARBLOCK file, writing it to 'outfile' + + :param infile: file-like object to read the crypto in VARBLOCK format from + :param outfile: file-like object to write the cleartext to + :param priv_key: :py:class:`rsa.PrivateKey` to decrypt with + + ''' + + if not isinstance(priv_key, key.PrivateKey): + raise TypeError('Private key required, but got %r' % priv_key) + + for block in varblock.yield_varblocks(infile): + cleartext = pkcs1.decrypt(block, priv_key) + outfile.write(cleartext) + +__all__ = ['encrypt_bigfile', 'decrypt_bigfile'] + diff --git a/src/lib/rsa/cli.py b/src/lib/rsa/cli.py new file mode 100644 index 00000000..527cc497 --- /dev/null +++ b/src/lib/rsa/cli.py @@ -0,0 +1,379 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Commandline scripts. + +These scripts are called by the executables defined in setup.py. +''' + +from __future__ import with_statement, print_function + +import abc +import sys +from optparse import OptionParser + +import rsa +import rsa.bigfile +import rsa.pkcs1 + +HASH_METHODS = sorted(rsa.pkcs1.HASH_METHODS.keys()) + +def keygen(): + '''Key generator.''' + + # Parse the CLI options + parser = OptionParser(usage='usage: %prog [options] keysize', + description='Generates a new RSA keypair of "keysize" bits.') + + parser.add_option('--pubout', type='string', + help='Output filename for the public key. The public key is ' + 'not saved if this option is not present. You can use ' + 'pyrsa-priv2pub to create the public key file later.') + + parser.add_option('-o', '--out', type='string', + help='Output filename for the private key. The key is ' + 'written to stdout if this option is not present.') + + parser.add_option('--form', + help='key format of the private and public keys - default PEM', + choices=('PEM', 'DER'), default='PEM') + + (cli, cli_args) = parser.parse_args(sys.argv[1:]) + + if len(cli_args) != 1: + parser.print_help() + raise SystemExit(1) + + try: + keysize = int(cli_args[0]) + except ValueError: + parser.print_help() + print('Not a valid number: %s' % cli_args[0], file=sys.stderr) + raise SystemExit(1) + + print('Generating %i-bit key' % keysize, file=sys.stderr) + (pub_key, priv_key) = rsa.newkeys(keysize) + + + # Save public key + if cli.pubout: + print('Writing public key to %s' % cli.pubout, file=sys.stderr) + data = pub_key.save_pkcs1(format=cli.form) + with open(cli.pubout, 'wb') as outfile: + outfile.write(data) + + # Save private key + data = priv_key.save_pkcs1(format=cli.form) + + if cli.out: + print('Writing private key to %s' % cli.out, file=sys.stderr) + with open(cli.out, 'wb') as outfile: + outfile.write(data) + else: + print('Writing private key to stdout', file=sys.stderr) + sys.stdout.write(data) + + +class CryptoOperation(object): + '''CLI callable that operates with input, output, and a key.''' + + __metaclass__ = abc.ABCMeta + + keyname = 'public' # or 'private' + usage = 'usage: %%prog [options] %(keyname)s_key' + description = None + operation = 'decrypt' + operation_past = 'decrypted' + operation_progressive = 'decrypting' + input_help = 'Name of the file to %(operation)s. Reads from stdin if ' \ + 'not specified.' + output_help = 'Name of the file to write the %(operation_past)s file ' \ + 'to. Written to stdout if this option is not present.' + expected_cli_args = 1 + has_output = True + + key_class = rsa.PublicKey + + def __init__(self): + self.usage = self.usage % self.__class__.__dict__ + self.input_help = self.input_help % self.__class__.__dict__ + self.output_help = self.output_help % self.__class__.__dict__ + + @abc.abstractmethod + def perform_operation(self, indata, key, cli_args=None): + '''Performs the program's operation. + + Implement in a subclass. + + :returns: the data to write to the output. + ''' + + def __call__(self): + '''Runs the program.''' + + (cli, cli_args) = self.parse_cli() + + key = self.read_key(cli_args[0], cli.keyform) + + indata = self.read_infile(cli.input) + + print(self.operation_progressive.title(), file=sys.stderr) + outdata = self.perform_operation(indata, key, cli_args) + + if self.has_output: + self.write_outfile(outdata, cli.output) + + def parse_cli(self): + '''Parse the CLI options + + :returns: (cli_opts, cli_args) + ''' + + parser = OptionParser(usage=self.usage, description=self.description) + + parser.add_option('-i', '--input', type='string', help=self.input_help) + + if self.has_output: + parser.add_option('-o', '--output', type='string', help=self.output_help) + + parser.add_option('--keyform', + help='Key format of the %s key - default PEM' % self.keyname, + choices=('PEM', 'DER'), default='PEM') + + (cli, cli_args) = parser.parse_args(sys.argv[1:]) + + if len(cli_args) != self.expected_cli_args: + parser.print_help() + raise SystemExit(1) + + return (cli, cli_args) + + def read_key(self, filename, keyform): + '''Reads a public or private key.''' + + print('Reading %s key from %s' % (self.keyname, filename), file=sys.stderr) + with open(filename, 'rb') as keyfile: + keydata = keyfile.read() + + return self.key_class.load_pkcs1(keydata, keyform) + + def read_infile(self, inname): + '''Read the input file''' + + if inname: + print('Reading input from %s' % inname, file=sys.stderr) + with open(inname, 'rb') as infile: + return infile.read() + + print('Reading input from stdin', file=sys.stderr) + return sys.stdin.read() + + def write_outfile(self, outdata, outname): + '''Write the output file''' + + if outname: + print('Writing output to %s' % outname, file=sys.stderr) + with open(outname, 'wb') as outfile: + outfile.write(outdata) + else: + print('Writing output to stdout', file=sys.stderr) + sys.stdout.write(outdata) + +class EncryptOperation(CryptoOperation): + '''Encrypts a file.''' + + keyname = 'public' + description = ('Encrypts a file. The file must be shorter than the key ' + 'length in order to be encrypted. For larger files, use the ' + 'pyrsa-encrypt-bigfile command.') + operation = 'encrypt' + operation_past = 'encrypted' + operation_progressive = 'encrypting' + + + def perform_operation(self, indata, pub_key, cli_args=None): + '''Encrypts files.''' + + return rsa.encrypt(indata, pub_key) + +class DecryptOperation(CryptoOperation): + '''Decrypts a file.''' + + keyname = 'private' + description = ('Decrypts a file. The original file must be shorter than ' + 'the key length in order to have been encrypted. For larger ' + 'files, use the pyrsa-decrypt-bigfile command.') + operation = 'decrypt' + operation_past = 'decrypted' + operation_progressive = 'decrypting' + key_class = rsa.PrivateKey + + def perform_operation(self, indata, priv_key, cli_args=None): + '''Decrypts files.''' + + return rsa.decrypt(indata, priv_key) + +class SignOperation(CryptoOperation): + '''Signs a file.''' + + keyname = 'private' + usage = 'usage: %%prog [options] private_key hash_method' + description = ('Signs a file, outputs the signature. Choose the hash ' + 'method from %s' % ', '.join(HASH_METHODS)) + operation = 'sign' + operation_past = 'signature' + operation_progressive = 'Signing' + key_class = rsa.PrivateKey + expected_cli_args = 2 + + output_help = ('Name of the file to write the signature to. Written ' + 'to stdout if this option is not present.') + + def perform_operation(self, indata, priv_key, cli_args): + '''Decrypts files.''' + + hash_method = cli_args[1] + if hash_method not in HASH_METHODS: + raise SystemExit('Invalid hash method, choose one of %s' % + ', '.join(HASH_METHODS)) + + return rsa.sign(indata, priv_key, hash_method) + +class VerifyOperation(CryptoOperation): + '''Verify a signature.''' + + keyname = 'public' + usage = 'usage: %%prog [options] public_key signature_file' + description = ('Verifies a signature, exits with status 0 upon success, ' + 'prints an error message and exits with status 1 upon error.') + operation = 'verify' + operation_past = 'verified' + operation_progressive = 'Verifying' + key_class = rsa.PublicKey + expected_cli_args = 2 + has_output = False + + def perform_operation(self, indata, pub_key, cli_args): + '''Decrypts files.''' + + signature_file = cli_args[1] + + with open(signature_file, 'rb') as sigfile: + signature = sigfile.read() + + try: + rsa.verify(indata, signature, pub_key) + except rsa.VerificationError: + raise SystemExit('Verification failed.') + + print('Verification OK', file=sys.stderr) + + +class BigfileOperation(CryptoOperation): + '''CryptoOperation that doesn't read the entire file into memory.''' + + def __init__(self): + CryptoOperation.__init__(self) + + self.file_objects = [] + + def __del__(self): + '''Closes any open file handles.''' + + for fobj in self.file_objects: + fobj.close() + + def __call__(self): + '''Runs the program.''' + + (cli, cli_args) = self.parse_cli() + + key = self.read_key(cli_args[0], cli.keyform) + + # Get the file handles + infile = self.get_infile(cli.input) + outfile = self.get_outfile(cli.output) + + # Call the operation + print(self.operation_progressive.title(), file=sys.stderr) + self.perform_operation(infile, outfile, key, cli_args) + + def get_infile(self, inname): + '''Returns the input file object''' + + if inname: + print('Reading input from %s' % inname, file=sys.stderr) + fobj = open(inname, 'rb') + self.file_objects.append(fobj) + else: + print('Reading input from stdin', file=sys.stderr) + fobj = sys.stdin + + return fobj + + def get_outfile(self, outname): + '''Returns the output file object''' + + if outname: + print('Will write output to %s' % outname, file=sys.stderr) + fobj = open(outname, 'wb') + self.file_objects.append(fobj) + else: + print('Will write output to stdout', file=sys.stderr) + fobj = sys.stdout + + return fobj + +class EncryptBigfileOperation(BigfileOperation): + '''Encrypts a file to VARBLOCK format.''' + + keyname = 'public' + description = ('Encrypts a file to an encrypted VARBLOCK file. The file ' + 'can be larger than the key length, but the output file is only ' + 'compatible with Python-RSA.') + operation = 'encrypt' + operation_past = 'encrypted' + operation_progressive = 'encrypting' + + def perform_operation(self, infile, outfile, pub_key, cli_args=None): + '''Encrypts files to VARBLOCK.''' + + return rsa.bigfile.encrypt_bigfile(infile, outfile, pub_key) + +class DecryptBigfileOperation(BigfileOperation): + '''Decrypts a file in VARBLOCK format.''' + + keyname = 'private' + description = ('Decrypts an encrypted VARBLOCK file that was encrypted ' + 'with pyrsa-encrypt-bigfile') + operation = 'decrypt' + operation_past = 'decrypted' + operation_progressive = 'decrypting' + key_class = rsa.PrivateKey + + def perform_operation(self, infile, outfile, priv_key, cli_args=None): + '''Decrypts a VARBLOCK file.''' + + return rsa.bigfile.decrypt_bigfile(infile, outfile, priv_key) + + +encrypt = EncryptOperation() +decrypt = DecryptOperation() +sign = SignOperation() +verify = VerifyOperation() +encrypt_bigfile = EncryptBigfileOperation() +decrypt_bigfile = DecryptBigfileOperation() + diff --git a/src/lib/rsa/common.py b/src/lib/rsa/common.py new file mode 100644 index 00000000..39feb8c2 --- /dev/null +++ b/src/lib/rsa/common.py @@ -0,0 +1,185 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Common functionality shared by several modules.''' + + +def bit_size(num): + ''' + Number of bits needed to represent a integer excluding any prefix + 0 bits. + + As per definition from http://wiki.python.org/moin/BitManipulation and + to match the behavior of the Python 3 API. + + Usage:: + + >>> bit_size(1023) + 10 + >>> bit_size(1024) + 11 + >>> bit_size(1025) + 11 + + :param num: + Integer value. If num is 0, returns 0. Only the absolute value of the + number is considered. Therefore, signed integers will be abs(num) + before the number's bit length is determined. + :returns: + Returns the number of bits in the integer. + ''' + if num == 0: + return 0 + if num < 0: + num = -num + + # Make sure this is an int and not a float. + num & 1 + + hex_num = "%x" % num + return ((len(hex_num) - 1) * 4) + { + '0':0, '1':1, '2':2, '3':2, + '4':3, '5':3, '6':3, '7':3, + '8':4, '9':4, 'a':4, 'b':4, + 'c':4, 'd':4, 'e':4, 'f':4, + }[hex_num[0]] + + +def _bit_size(number): + ''' + Returns the number of bits required to hold a specific long number. + ''' + if number < 0: + raise ValueError('Only nonnegative numbers possible: %s' % number) + + if number == 0: + return 0 + + # This works, even with very large numbers. When using math.log(number, 2), + # you'll get rounding errors and it'll fail. + bits = 0 + while number: + bits += 1 + number >>= 1 + + return bits + + +def byte_size(number): + ''' + Returns the number of bytes required to hold a specific long number. + + The number of bytes is rounded up. + + Usage:: + + >>> byte_size(1 << 1023) + 128 + >>> byte_size((1 << 1024) - 1) + 128 + >>> byte_size(1 << 1024) + 129 + + :param number: + An unsigned integer + :returns: + The number of bytes required to hold a specific long number. + ''' + quanta, mod = divmod(bit_size(number), 8) + if mod or number == 0: + quanta += 1 + return quanta + #return int(math.ceil(bit_size(number) / 8.0)) + + +def extended_gcd(a, b): + '''Returns a tuple (r, i, j) such that r = gcd(a, b) = ia + jb + ''' + # r = gcd(a,b) i = multiplicitive inverse of a mod b + # or j = multiplicitive inverse of b mod a + # Neg return values for i or j are made positive mod b or a respectively + # Iterateive Version is faster and uses much less stack space + x = 0 + y = 1 + lx = 1 + ly = 0 + oa = a #Remember original a/b to remove + ob = b #negative values from return results + while b != 0: + q = a // b + (a, b) = (b, a % b) + (x, lx) = ((lx - (q * x)),x) + (y, ly) = ((ly - (q * y)),y) + if (lx < 0): lx += ob #If neg wrap modulo orignal b + if (ly < 0): ly += oa #If neg wrap modulo orignal a + return (a, lx, ly) #Return only positive values + + +def inverse(x, n): + '''Returns x^-1 (mod n) + + >>> inverse(7, 4) + 3 + >>> (inverse(143, 4) * 143) % 4 + 1 + ''' + + (divider, inv, _) = extended_gcd(x, n) + + if divider != 1: + raise ValueError("x (%d) and n (%d) are not relatively prime" % (x, n)) + + return inv + + +def crt(a_values, modulo_values): + '''Chinese Remainder Theorem. + + Calculates x such that x = a[i] (mod m[i]) for each i. + + :param a_values: the a-values of the above equation + :param modulo_values: the m-values of the above equation + :returns: x such that x = a[i] (mod m[i]) for each i + + + >>> crt([2, 3], [3, 5]) + 8 + + >>> crt([2, 3, 2], [3, 5, 7]) + 23 + + >>> crt([2, 3, 0], [7, 11, 15]) + 135 + ''' + + m = 1 + x = 0 + + for modulo in modulo_values: + m *= modulo + + for (m_i, a_i) in zip(modulo_values, a_values): + M_i = m // m_i + inv = inverse(M_i, m_i) + + x = (x + a_i * M_i * inv) % m + + return x + +if __name__ == '__main__': + import doctest + doctest.testmod() + diff --git a/src/lib/rsa/core.py b/src/lib/rsa/core.py new file mode 100644 index 00000000..90dfee8e --- /dev/null +++ b/src/lib/rsa/core.py @@ -0,0 +1,58 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Core mathematical operations. + +This is the actual core RSA implementation, which is only defined +mathematically on integers. +''' + + +from rsa._compat import is_integer + +def assert_int(var, name): + + if is_integer(var): + return + + raise TypeError('%s should be an integer, not %s' % (name, var.__class__)) + +def encrypt_int(message, ekey, n): + '''Encrypts a message using encryption key 'ekey', working modulo n''' + + assert_int(message, 'message') + assert_int(ekey, 'ekey') + assert_int(n, 'n') + + if message < 0: + raise ValueError('Only non-negative numbers are supported') + + if message > n: + raise OverflowError("The message %i is too long for n=%i" % (message, n)) + + return pow(message, ekey, n) + +def decrypt_int(cyphertext, dkey, n): + '''Decrypts a cypher text using the decryption key 'dkey', working + modulo n''' + + assert_int(cyphertext, 'cyphertext') + assert_int(dkey, 'dkey') + assert_int(n, 'n') + + message = pow(cyphertext, dkey, n) + return message + diff --git a/src/lib/rsa/key.py b/src/lib/rsa/key.py new file mode 100644 index 00000000..b6de7b3f --- /dev/null +++ b/src/lib/rsa/key.py @@ -0,0 +1,612 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''RSA key generation code. + +Create new keys with the newkeys() function. It will give you a PublicKey and a +PrivateKey object. + +Loading and saving keys requires the pyasn1 module. This module is imported as +late as possible, such that other functionality will remain working in absence +of pyasn1. + +''' + +import logging +from rsa._compat import b, bytes_type + +import rsa.prime +import rsa.pem +import rsa.common + +log = logging.getLogger(__name__) + + + +class AbstractKey(object): + '''Abstract superclass for private and public keys.''' + + @classmethod + def load_pkcs1(cls, keyfile, format='PEM'): + r'''Loads a key in PKCS#1 DER or PEM format. + + :param keyfile: contents of a DER- or PEM-encoded file that contains + the public key. + :param format: the format of the file to load; 'PEM' or 'DER' + + :return: a PublicKey object + + ''' + + methods = { + 'PEM': cls._load_pkcs1_pem, + 'DER': cls._load_pkcs1_der, + } + + if format not in methods: + formats = ', '.join(sorted(methods.keys())) + raise ValueError('Unsupported format: %r, try one of %s' % (format, + formats)) + + method = methods[format] + return method(keyfile) + + def save_pkcs1(self, format='PEM'): + '''Saves the public key in PKCS#1 DER or PEM format. + + :param format: the format to save; 'PEM' or 'DER' + :returns: the DER- or PEM-encoded public key. + + ''' + + methods = { + 'PEM': self._save_pkcs1_pem, + 'DER': self._save_pkcs1_der, + } + + if format not in methods: + formats = ', '.join(sorted(methods.keys())) + raise ValueError('Unsupported format: %r, try one of %s' % (format, + formats)) + + method = methods[format] + return method() + +class PublicKey(AbstractKey): + '''Represents a public RSA key. + + This key is also known as the 'encryption key'. It contains the 'n' and 'e' + values. + + Supports attributes as well as dictionary-like access. Attribute accesss is + faster, though. + + >>> PublicKey(5, 3) + PublicKey(5, 3) + + >>> key = PublicKey(5, 3) + >>> key.n + 5 + >>> key['n'] + 5 + >>> key.e + 3 + >>> key['e'] + 3 + + ''' + + __slots__ = ('n', 'e') + + def __init__(self, n, e): + self.n = n + self.e = e + + def __getitem__(self, key): + return getattr(self, key) + + def __repr__(self): + return 'PublicKey(%i, %i)' % (self.n, self.e) + + def __eq__(self, other): + if other is None: + return False + + if not isinstance(other, PublicKey): + return False + + return self.n == other.n and self.e == other.e + + def __ne__(self, other): + return not (self == other) + + @classmethod + def _load_pkcs1_der(cls, keyfile): + r'''Loads a key in PKCS#1 DER format. + + @param keyfile: contents of a DER-encoded file that contains the public + key. + @return: a PublicKey object + + First let's construct a DER encoded key: + + >>> import base64 + >>> b64der = 'MAwCBQCNGmYtAgMBAAE=' + >>> der = base64.decodestring(b64der) + + This loads the file: + + >>> PublicKey._load_pkcs1_der(der) + PublicKey(2367317549, 65537) + + ''' + + from pyasn1.codec.der import decoder + from rsa.asn1 import AsnPubKey + + (priv, _) = decoder.decode(keyfile, asn1Spec=AsnPubKey()) + return cls(n=int(priv['modulus']), e=int(priv['publicExponent'])) + + def _save_pkcs1_der(self): + '''Saves the public key in PKCS#1 DER format. + + @returns: the DER-encoded public key. + ''' + + from pyasn1.codec.der import encoder + from rsa.asn1 import AsnPubKey + + # Create the ASN object + asn_key = AsnPubKey() + asn_key.setComponentByName('modulus', self.n) + asn_key.setComponentByName('publicExponent', self.e) + + return encoder.encode(asn_key) + + @classmethod + def _load_pkcs1_pem(cls, keyfile): + '''Loads a PKCS#1 PEM-encoded public key file. + + The contents of the file before the "-----BEGIN RSA PUBLIC KEY-----" and + after the "-----END RSA PUBLIC KEY-----" lines is ignored. + + @param keyfile: contents of a PEM-encoded file that contains the public + key. + @return: a PublicKey object + ''' + + der = rsa.pem.load_pem(keyfile, 'RSA PUBLIC KEY') + return cls._load_pkcs1_der(der) + + def _save_pkcs1_pem(self): + '''Saves a PKCS#1 PEM-encoded public key file. + + @return: contents of a PEM-encoded file that contains the public key. + ''' + + der = self._save_pkcs1_der() + return rsa.pem.save_pem(der, 'RSA PUBLIC KEY') + + @classmethod + def load_pkcs1_openssl_pem(cls, keyfile): + '''Loads a PKCS#1.5 PEM-encoded public key file from OpenSSL. + + These files can be recognised in that they start with BEGIN PUBLIC KEY + rather than BEGIN RSA PUBLIC KEY. + + The contents of the file before the "-----BEGIN PUBLIC KEY-----" and + after the "-----END PUBLIC KEY-----" lines is ignored. + + @param keyfile: contents of a PEM-encoded file that contains the public + key, from OpenSSL. + @return: a PublicKey object + ''' + + der = rsa.pem.load_pem(keyfile, 'PUBLIC KEY') + return cls.load_pkcs1_openssl_der(der) + + @classmethod + def load_pkcs1_openssl_der(cls, keyfile): + '''Loads a PKCS#1 DER-encoded public key file from OpenSSL. + + @param keyfile: contents of a DER-encoded file that contains the public + key, from OpenSSL. + @return: a PublicKey object + ''' + + from rsa.asn1 import OpenSSLPubKey + from pyasn1.codec.der import decoder + from pyasn1.type import univ + + (keyinfo, _) = decoder.decode(keyfile, asn1Spec=OpenSSLPubKey()) + + if keyinfo['header']['oid'] != univ.ObjectIdentifier('1.2.840.113549.1.1.1'): + raise TypeError("This is not a DER-encoded OpenSSL-compatible public key") + + return cls._load_pkcs1_der(keyinfo['key'][1:]) + + + + +class PrivateKey(AbstractKey): + '''Represents a private RSA key. + + This key is also known as the 'decryption key'. It contains the 'n', 'e', + 'd', 'p', 'q' and other values. + + Supports attributes as well as dictionary-like access. Attribute accesss is + faster, though. + + >>> PrivateKey(3247, 65537, 833, 191, 17) + PrivateKey(3247, 65537, 833, 191, 17) + + exp1, exp2 and coef don't have to be given, they will be calculated: + + >>> pk = PrivateKey(3727264081, 65537, 3349121513, 65063, 57287) + >>> pk.exp1 + 55063 + >>> pk.exp2 + 10095 + >>> pk.coef + 50797 + + If you give exp1, exp2 or coef, they will be used as-is: + + >>> pk = PrivateKey(1, 2, 3, 4, 5, 6, 7, 8) + >>> pk.exp1 + 6 + >>> pk.exp2 + 7 + >>> pk.coef + 8 + + ''' + + __slots__ = ('n', 'e', 'd', 'p', 'q', 'exp1', 'exp2', 'coef') + + def __init__(self, n, e, d, p, q, exp1=None, exp2=None, coef=None): + self.n = n + self.e = e + self.d = d + self.p = p + self.q = q + + # Calculate the other values if they aren't supplied + if exp1 is None: + self.exp1 = int(d % (p - 1)) + else: + self.exp1 = exp1 + + if exp1 is None: + self.exp2 = int(d % (q - 1)) + else: + self.exp2 = exp2 + + if coef is None: + self.coef = rsa.common.inverse(q, p) + else: + self.coef = coef + + def __getitem__(self, key): + return getattr(self, key) + + def __repr__(self): + return 'PrivateKey(%(n)i, %(e)i, %(d)i, %(p)i, %(q)i)' % self + + def __eq__(self, other): + if other is None: + return False + + if not isinstance(other, PrivateKey): + return False + + return (self.n == other.n and + self.e == other.e and + self.d == other.d and + self.p == other.p and + self.q == other.q and + self.exp1 == other.exp1 and + self.exp2 == other.exp2 and + self.coef == other.coef) + + def __ne__(self, other): + return not (self == other) + + @classmethod + def _load_pkcs1_der(cls, keyfile): + r'''Loads a key in PKCS#1 DER format. + + @param keyfile: contents of a DER-encoded file that contains the private + key. + @return: a PrivateKey object + + First let's construct a DER encoded key: + + >>> import base64 + >>> b64der = 'MC4CAQACBQDeKYlRAgMBAAECBQDHn4npAgMA/icCAwDfxwIDANcXAgInbwIDAMZt' + >>> der = base64.decodestring(b64der) + + This loads the file: + + >>> PrivateKey._load_pkcs1_der(der) + PrivateKey(3727264081, 65537, 3349121513, 65063, 57287) + + ''' + + from pyasn1.codec.der import decoder + (priv, _) = decoder.decode(keyfile) + + # ASN.1 contents of DER encoded private key: + # + # RSAPrivateKey ::= SEQUENCE { + # version Version, + # modulus INTEGER, -- n + # publicExponent INTEGER, -- e + # privateExponent INTEGER, -- d + # prime1 INTEGER, -- p + # prime2 INTEGER, -- q + # exponent1 INTEGER, -- d mod (p-1) + # exponent2 INTEGER, -- d mod (q-1) + # coefficient INTEGER, -- (inverse of q) mod p + # otherPrimeInfos OtherPrimeInfos OPTIONAL + # } + + if priv[0] != 0: + raise ValueError('Unable to read this file, version %s != 0' % priv[0]) + + as_ints = tuple(int(x) for x in priv[1:9]) + return cls(*as_ints) + + def _save_pkcs1_der(self): + '''Saves the private key in PKCS#1 DER format. + + @returns: the DER-encoded private key. + ''' + + from pyasn1.type import univ, namedtype + from pyasn1.codec.der import encoder + + class AsnPrivKey(univ.Sequence): + componentType = namedtype.NamedTypes( + namedtype.NamedType('version', univ.Integer()), + namedtype.NamedType('modulus', univ.Integer()), + namedtype.NamedType('publicExponent', univ.Integer()), + namedtype.NamedType('privateExponent', univ.Integer()), + namedtype.NamedType('prime1', univ.Integer()), + namedtype.NamedType('prime2', univ.Integer()), + namedtype.NamedType('exponent1', univ.Integer()), + namedtype.NamedType('exponent2', univ.Integer()), + namedtype.NamedType('coefficient', univ.Integer()), + ) + + # Create the ASN object + asn_key = AsnPrivKey() + asn_key.setComponentByName('version', 0) + asn_key.setComponentByName('modulus', self.n) + asn_key.setComponentByName('publicExponent', self.e) + asn_key.setComponentByName('privateExponent', self.d) + asn_key.setComponentByName('prime1', self.p) + asn_key.setComponentByName('prime2', self.q) + asn_key.setComponentByName('exponent1', self.exp1) + asn_key.setComponentByName('exponent2', self.exp2) + asn_key.setComponentByName('coefficient', self.coef) + + return encoder.encode(asn_key) + + @classmethod + def _load_pkcs1_pem(cls, keyfile): + '''Loads a PKCS#1 PEM-encoded private key file. + + The contents of the file before the "-----BEGIN RSA PRIVATE KEY-----" and + after the "-----END RSA PRIVATE KEY-----" lines is ignored. + + @param keyfile: contents of a PEM-encoded file that contains the private + key. + @return: a PrivateKey object + ''' + + der = rsa.pem.load_pem(keyfile, b('RSA PRIVATE KEY')) + return cls._load_pkcs1_der(der) + + def _save_pkcs1_pem(self): + '''Saves a PKCS#1 PEM-encoded private key file. + + @return: contents of a PEM-encoded file that contains the private key. + ''' + + der = self._save_pkcs1_der() + return rsa.pem.save_pem(der, b('RSA PRIVATE KEY')) + +def find_p_q(nbits, getprime_func=rsa.prime.getprime, accurate=True): + ''''Returns a tuple of two different primes of nbits bits each. + + The resulting p * q has exacty 2 * nbits bits, and the returned p and q + will not be equal. + + :param nbits: the number of bits in each of p and q. + :param getprime_func: the getprime function, defaults to + :py:func:`rsa.prime.getprime`. + + *Introduced in Python-RSA 3.1* + + :param accurate: whether to enable accurate mode or not. + :returns: (p, q), where p > q + + >>> (p, q) = find_p_q(128) + >>> from rsa import common + >>> common.bit_size(p * q) + 256 + + When not in accurate mode, the number of bits can be slightly less + + >>> (p, q) = find_p_q(128, accurate=False) + >>> from rsa import common + >>> common.bit_size(p * q) <= 256 + True + >>> common.bit_size(p * q) > 240 + True + + ''' + + total_bits = nbits * 2 + + # Make sure that p and q aren't too close or the factoring programs can + # factor n. + shift = nbits // 16 + pbits = nbits + shift + qbits = nbits - shift + + # Choose the two initial primes + log.debug('find_p_q(%i): Finding p', nbits) + p = getprime_func(pbits) + log.debug('find_p_q(%i): Finding q', nbits) + q = getprime_func(qbits) + + def is_acceptable(p, q): + '''Returns True iff p and q are acceptable: + + - p and q differ + - (p * q) has the right nr of bits (when accurate=True) + ''' + + if p == q: + return False + + if not accurate: + return True + + # Make sure we have just the right amount of bits + found_size = rsa.common.bit_size(p * q) + return total_bits == found_size + + # Keep choosing other primes until they match our requirements. + change_p = False + while not is_acceptable(p, q): + # Change p on one iteration and q on the other + if change_p: + p = getprime_func(pbits) + else: + q = getprime_func(qbits) + + change_p = not change_p + + # We want p > q as described on + # http://www.di-mgt.com.au/rsa_alg.html#crt + return (max(p, q), min(p, q)) + +def calculate_keys(p, q, nbits): + '''Calculates an encryption and a decryption key given p and q, and + returns them as a tuple (e, d) + + ''' + + phi_n = (p - 1) * (q - 1) + + # A very common choice for e is 65537 + e = 65537 + + try: + d = rsa.common.inverse(e, phi_n) + except ValueError: + raise ValueError("e (%d) and phi_n (%d) are not relatively prime" % + (e, phi_n)) + + if (e * d) % phi_n != 1: + raise ValueError("e (%d) and d (%d) are not mult. inv. modulo " + "phi_n (%d)" % (e, d, phi_n)) + + return (e, d) + +def gen_keys(nbits, getprime_func, accurate=True): + '''Generate RSA keys of nbits bits. Returns (p, q, e, d). + + Note: this can take a long time, depending on the key size. + + :param nbits: the total number of bits in ``p`` and ``q``. Both ``p`` and + ``q`` will use ``nbits/2`` bits. + :param getprime_func: either :py:func:`rsa.prime.getprime` or a function + with similar signature. + ''' + + (p, q) = find_p_q(nbits // 2, getprime_func, accurate) + (e, d) = calculate_keys(p, q, nbits // 2) + + return (p, q, e, d) + +def newkeys(nbits, accurate=True, poolsize=1): + '''Generates public and private keys, and returns them as (pub, priv). + + The public key is also known as the 'encryption key', and is a + :py:class:`rsa.PublicKey` object. The private key is also known as the + 'decryption key' and is a :py:class:`rsa.PrivateKey` object. + + :param nbits: the number of bits required to store ``n = p*q``. + :param accurate: when True, ``n`` will have exactly the number of bits you + asked for. However, this makes key generation much slower. When False, + `n`` may have slightly less bits. + :param poolsize: the number of processes to use to generate the prime + numbers. If set to a number > 1, a parallel algorithm will be used. + This requires Python 2.6 or newer. + + :returns: a tuple (:py:class:`rsa.PublicKey`, :py:class:`rsa.PrivateKey`) + + The ``poolsize`` parameter was added in *Python-RSA 3.1* and requires + Python 2.6 or newer. + + ''' + + if nbits < 16: + raise ValueError('Key too small') + + if poolsize < 1: + raise ValueError('Pool size (%i) should be >= 1' % poolsize) + + # Determine which getprime function to use + if poolsize > 1: + from rsa import parallel + import functools + + getprime_func = functools.partial(parallel.getprime, poolsize=poolsize) + else: getprime_func = rsa.prime.getprime + + # Generate the key components + (p, q, e, d) = gen_keys(nbits, getprime_func) + + # Create the key objects + n = p * q + + return ( + PublicKey(n, e), + PrivateKey(n, e, d, p, q) + ) + +__all__ = ['PublicKey', 'PrivateKey', 'newkeys'] + +if __name__ == '__main__': + import doctest + + try: + for count in range(100): + (failures, tests) = doctest.testmod() + if failures: + break + + if (count and count % 10 == 0) or count == 1: + print('%i times' % count) + except KeyboardInterrupt: + print('Aborted') + else: + print('Doctests done') diff --git a/src/lib/rsa/parallel.py b/src/lib/rsa/parallel.py new file mode 100644 index 00000000..e5034ac7 --- /dev/null +++ b/src/lib/rsa/parallel.py @@ -0,0 +1,94 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Functions for parallel computation on multiple cores. + +Introduced in Python-RSA 3.1. + +.. note:: + + Requires Python 2.6 or newer. + +''' + +from __future__ import print_function + +import multiprocessing as mp + +import rsa.prime +import rsa.randnum + +def _find_prime(nbits, pipe): + while True: + integer = rsa.randnum.read_random_int(nbits) + + # Make sure it's odd + integer |= 1 + + # Test for primeness + if rsa.prime.is_prime(integer): + pipe.send(integer) + return + +def getprime(nbits, poolsize): + '''Returns a prime number that can be stored in 'nbits' bits. + + Works in multiple threads at the same time. + + >>> p = getprime(128, 3) + >>> rsa.prime.is_prime(p-1) + False + >>> rsa.prime.is_prime(p) + True + >>> rsa.prime.is_prime(p+1) + False + + >>> from rsa import common + >>> common.bit_size(p) == 128 + True + + ''' + + (pipe_recv, pipe_send) = mp.Pipe(duplex=False) + + # Create processes + procs = [mp.Process(target=_find_prime, args=(nbits, pipe_send)) + for _ in range(poolsize)] + [p.start() for p in procs] + + result = pipe_recv.recv() + + [p.terminate() for p in procs] + + return result + +__all__ = ['getprime'] + + +if __name__ == '__main__': + print('Running doctests 1000x or until failure') + import doctest + + for count in range(100): + (failures, tests) = doctest.testmod() + if failures: + break + + if count and count % 10 == 0: + print('%i times' % count) + + print('Doctests done') + diff --git a/src/lib/rsa/pem.py b/src/lib/rsa/pem.py new file mode 100644 index 00000000..b1c3a0ed --- /dev/null +++ b/src/lib/rsa/pem.py @@ -0,0 +1,120 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Functions that load and write PEM-encoded files.''' + +import base64 +from rsa._compat import b, is_bytes + +def _markers(pem_marker): + ''' + Returns the start and end PEM markers + ''' + + if is_bytes(pem_marker): + pem_marker = pem_marker.decode('utf-8') + + return (b('-----BEGIN %s-----' % pem_marker), + b('-----END %s-----' % pem_marker)) + +def load_pem(contents, pem_marker): + '''Loads a PEM file. + + @param contents: the contents of the file to interpret + @param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY' + when your file has '-----BEGIN RSA PRIVATE KEY-----' and + '-----END RSA PRIVATE KEY-----' markers. + + @return the base64-decoded content between the start and end markers. + + @raise ValueError: when the content is invalid, for example when the start + marker cannot be found. + + ''' + + (pem_start, pem_end) = _markers(pem_marker) + + pem_lines = [] + in_pem_part = False + + for line in contents.splitlines(): + line = line.strip() + + # Skip empty lines + if not line: + continue + + # Handle start marker + if line == pem_start: + if in_pem_part: + raise ValueError('Seen start marker "%s" twice' % pem_start) + + in_pem_part = True + continue + + # Skip stuff before first marker + if not in_pem_part: + continue + + # Handle end marker + if in_pem_part and line == pem_end: + in_pem_part = False + break + + # Load fields + if b(':') in line: + continue + + pem_lines.append(line) + + # Do some sanity checks + if not pem_lines: + raise ValueError('No PEM start marker "%s" found' % pem_start) + + if in_pem_part: + raise ValueError('No PEM end marker "%s" found' % pem_end) + + # Base64-decode the contents + pem = b('').join(pem_lines) + return base64.decodestring(pem) + + +def save_pem(contents, pem_marker): + '''Saves a PEM file. + + @param contents: the contents to encode in PEM format + @param pem_marker: the marker of the PEM content, such as 'RSA PRIVATE KEY' + when your file has '-----BEGIN RSA PRIVATE KEY-----' and + '-----END RSA PRIVATE KEY-----' markers. + + @return the base64-encoded content between the start and end markers. + + ''' + + (pem_start, pem_end) = _markers(pem_marker) + + b64 = base64.encodestring(contents).replace(b('\n'), b('')) + pem_lines = [pem_start] + + for block_start in range(0, len(b64), 64): + block = b64[block_start:block_start + 64] + pem_lines.append(block) + + pem_lines.append(pem_end) + pem_lines.append(b('')) + + return b('\n').join(pem_lines) + diff --git a/src/lib/rsa/pkcs1.py b/src/lib/rsa/pkcs1.py new file mode 100644 index 00000000..15e4cf63 --- /dev/null +++ b/src/lib/rsa/pkcs1.py @@ -0,0 +1,391 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Functions for PKCS#1 version 1.5 encryption and signing + +This module implements certain functionality from PKCS#1 version 1.5. For a +very clear example, read http://www.di-mgt.com.au/rsa_alg.html#pkcs1schemes + +At least 8 bytes of random padding is used when encrypting a message. This makes +these methods much more secure than the ones in the ``rsa`` module. + +WARNING: this module leaks information when decryption or verification fails. +The exceptions that are raised contain the Python traceback information, which +can be used to deduce where in the process the failure occurred. DO NOT PASS +SUCH INFORMATION to your users. +''' + +import hashlib +import os + +from rsa._compat import b +from rsa import common, transform, core, varblock + +# ASN.1 codes that describe the hash algorithm used. +HASH_ASN1 = { + 'MD5': b('\x30\x20\x30\x0c\x06\x08\x2a\x86\x48\x86\xf7\x0d\x02\x05\x05\x00\x04\x10'), + 'SHA-1': b('\x30\x21\x30\x09\x06\x05\x2b\x0e\x03\x02\x1a\x05\x00\x04\x14'), + 'SHA-256': b('\x30\x31\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x01\x05\x00\x04\x20'), + 'SHA-384': b('\x30\x41\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x02\x05\x00\x04\x30'), + 'SHA-512': b('\x30\x51\x30\x0d\x06\x09\x60\x86\x48\x01\x65\x03\x04\x02\x03\x05\x00\x04\x40'), +} + +HASH_METHODS = { + 'MD5': hashlib.md5, + 'SHA-1': hashlib.sha1, + 'SHA-256': hashlib.sha256, + 'SHA-384': hashlib.sha384, + 'SHA-512': hashlib.sha512, +} + +class CryptoError(Exception): + '''Base class for all exceptions in this module.''' + +class DecryptionError(CryptoError): + '''Raised when decryption fails.''' + +class VerificationError(CryptoError): + '''Raised when verification fails.''' + +def _pad_for_encryption(message, target_length): + r'''Pads the message for encryption, returning the padded message. + + :return: 00 02 RANDOM_DATA 00 MESSAGE + + >>> block = _pad_for_encryption('hello', 16) + >>> len(block) + 16 + >>> block[0:2] + '\x00\x02' + >>> block[-6:] + '\x00hello' + + ''' + + max_msglength = target_length - 11 + msglength = len(message) + + if msglength > max_msglength: + raise OverflowError('%i bytes needed for message, but there is only' + ' space for %i' % (msglength, max_msglength)) + + # Get random padding + padding = b('') + padding_length = target_length - msglength - 3 + + # We remove 0-bytes, so we'll end up with less padding than we've asked for, + # so keep adding data until we're at the correct length. + while len(padding) < padding_length: + needed_bytes = padding_length - len(padding) + + # Always read at least 8 bytes more than we need, and trim off the rest + # after removing the 0-bytes. This increases the chance of getting + # enough bytes, especially when needed_bytes is small + new_padding = os.urandom(needed_bytes + 5) + new_padding = new_padding.replace(b('\x00'), b('')) + padding = padding + new_padding[:needed_bytes] + + assert len(padding) == padding_length + + return b('').join([b('\x00\x02'), + padding, + b('\x00'), + message]) + + +def _pad_for_signing(message, target_length): + r'''Pads the message for signing, returning the padded message. + + The padding is always a repetition of FF bytes. + + :return: 00 01 PADDING 00 MESSAGE + + >>> block = _pad_for_signing('hello', 16) + >>> len(block) + 16 + >>> block[0:2] + '\x00\x01' + >>> block[-6:] + '\x00hello' + >>> block[2:-6] + '\xff\xff\xff\xff\xff\xff\xff\xff' + + ''' + + max_msglength = target_length - 11 + msglength = len(message) + + if msglength > max_msglength: + raise OverflowError('%i bytes needed for message, but there is only' + ' space for %i' % (msglength, max_msglength)) + + padding_length = target_length - msglength - 3 + + return b('').join([b('\x00\x01'), + padding_length * b('\xff'), + b('\x00'), + message]) + + +def encrypt(message, pub_key): + '''Encrypts the given message using PKCS#1 v1.5 + + :param message: the message to encrypt. Must be a byte string no longer than + ``k-11`` bytes, where ``k`` is the number of bytes needed to encode + the ``n`` component of the public key. + :param pub_key: the :py:class:`rsa.PublicKey` to encrypt with. + :raise OverflowError: when the message is too large to fit in the padded + block. + + >>> from rsa import key, common + >>> (pub_key, priv_key) = key.newkeys(256) + >>> message = 'hello' + >>> crypto = encrypt(message, pub_key) + + The crypto text should be just as long as the public key 'n' component: + + >>> len(crypto) == common.byte_size(pub_key.n) + True + + ''' + + keylength = common.byte_size(pub_key.n) + padded = _pad_for_encryption(message, keylength) + + payload = transform.bytes2int(padded) + encrypted = core.encrypt_int(payload, pub_key.e, pub_key.n) + block = transform.int2bytes(encrypted, keylength) + + return block + +def decrypt(crypto, priv_key): + r'''Decrypts the given message using PKCS#1 v1.5 + + The decryption is considered 'failed' when the resulting cleartext doesn't + start with the bytes 00 02, or when the 00 byte between the padding and + the message cannot be found. + + :param crypto: the crypto text as returned by :py:func:`rsa.encrypt` + :param priv_key: the :py:class:`rsa.PrivateKey` to decrypt with. + :raise DecryptionError: when the decryption fails. No details are given as + to why the code thinks the decryption fails, as this would leak + information about the private key. + + + >>> import rsa + >>> (pub_key, priv_key) = rsa.newkeys(256) + + It works with strings: + + >>> crypto = encrypt('hello', pub_key) + >>> decrypt(crypto, priv_key) + 'hello' + + And with binary data: + + >>> crypto = encrypt('\x00\x00\x00\x00\x01', pub_key) + >>> decrypt(crypto, priv_key) + '\x00\x00\x00\x00\x01' + + Altering the encrypted information will *likely* cause a + :py:class:`rsa.pkcs1.DecryptionError`. If you want to be *sure*, use + :py:func:`rsa.sign`. + + + .. warning:: + + Never display the stack trace of a + :py:class:`rsa.pkcs1.DecryptionError` exception. It shows where in the + code the exception occurred, and thus leaks information about the key. + It's only a tiny bit of information, but every bit makes cracking the + keys easier. + + >>> crypto = encrypt('hello', pub_key) + >>> crypto = crypto[0:5] + 'X' + crypto[6:] # change a byte + >>> decrypt(crypto, priv_key) + Traceback (most recent call last): + ... + DecryptionError: Decryption failed + + ''' + + blocksize = common.byte_size(priv_key.n) + encrypted = transform.bytes2int(crypto) + decrypted = core.decrypt_int(encrypted, priv_key.d, priv_key.n) + cleartext = transform.int2bytes(decrypted, blocksize) + + # If we can't find the cleartext marker, decryption failed. + if cleartext[0:2] != b('\x00\x02'): + raise DecryptionError('Decryption failed') + + # Find the 00 separator between the padding and the message + try: + sep_idx = cleartext.index(b('\x00'), 2) + except ValueError: + raise DecryptionError('Decryption failed') + + return cleartext[sep_idx+1:] + +def sign(message, priv_key, hash): + '''Signs the message with the private key. + + Hashes the message, then signs the hash with the given key. This is known + as a "detached signature", because the message itself isn't altered. + + :param message: the message to sign. Can be an 8-bit string or a file-like + object. If ``message`` has a ``read()`` method, it is assumed to be a + file-like object. + :param priv_key: the :py:class:`rsa.PrivateKey` to sign with + :param hash: the hash method used on the message. Use 'MD5', 'SHA-1', + 'SHA-256', 'SHA-384' or 'SHA-512'. + :return: a message signature block. + :raise OverflowError: if the private key is too small to contain the + requested hash. + + ''' + + # Get the ASN1 code for this hash method + if hash not in HASH_ASN1: + raise ValueError('Invalid hash method: %s' % hash) + asn1code = HASH_ASN1[hash] + + # Calculate the hash + hash = _hash(message, hash) + + # Encrypt the hash with the private key + cleartext = asn1code + hash + keylength = common.byte_size(priv_key.n) + padded = _pad_for_signing(cleartext, keylength) + + payload = transform.bytes2int(padded) + encrypted = core.encrypt_int(payload, priv_key.d, priv_key.n) + block = transform.int2bytes(encrypted, keylength) + + return block + +def verify(message, signature, pub_key): + '''Verifies that the signature matches the message. + + The hash method is detected automatically from the signature. + + :param message: the signed message. Can be an 8-bit string or a file-like + object. If ``message`` has a ``read()`` method, it is assumed to be a + file-like object. + :param signature: the signature block, as created with :py:func:`rsa.sign`. + :param pub_key: the :py:class:`rsa.PublicKey` of the person signing the message. + :raise VerificationError: when the signature doesn't match the message. + + .. warning:: + + Never display the stack trace of a + :py:class:`rsa.pkcs1.VerificationError` exception. It shows where in + the code the exception occurred, and thus leaks information about the + key. It's only a tiny bit of information, but every bit makes cracking + the keys easier. + + ''' + + blocksize = common.byte_size(pub_key.n) + encrypted = transform.bytes2int(signature) + decrypted = core.decrypt_int(encrypted, pub_key.e, pub_key.n) + clearsig = transform.int2bytes(decrypted, blocksize) + + # If we can't find the signature marker, verification failed. + if clearsig[0:2] != b('\x00\x01'): + raise VerificationError('Verification failed') + + # Find the 00 separator between the padding and the payload + try: + sep_idx = clearsig.index(b('\x00'), 2) + except ValueError: + raise VerificationError('Verification failed') + + # Get the hash and the hash method + (method_name, signature_hash) = _find_method_hash(clearsig[sep_idx+1:]) + message_hash = _hash(message, method_name) + + # Compare the real hash to the hash in the signature + if message_hash != signature_hash: + raise VerificationError('Verification failed') + + return True + +def _hash(message, method_name): + '''Returns the message digest. + + :param message: the signed message. Can be an 8-bit string or a file-like + object. If ``message`` has a ``read()`` method, it is assumed to be a + file-like object. + :param method_name: the hash method, must be a key of + :py:const:`HASH_METHODS`. + + ''' + + if method_name not in HASH_METHODS: + raise ValueError('Invalid hash method: %s' % method_name) + + method = HASH_METHODS[method_name] + hasher = method() + + if hasattr(message, 'read') and hasattr(message.read, '__call__'): + # read as 1K blocks + for block in varblock.yield_fixedblocks(message, 1024): + hasher.update(block) + else: + # hash the message object itself. + hasher.update(message) + + return hasher.digest() + + +def _find_method_hash(method_hash): + '''Finds the hash method and the hash itself. + + :param method_hash: ASN1 code for the hash method concatenated with the + hash itself. + + :return: tuple (method, hash) where ``method`` is the used hash method, and + ``hash`` is the hash itself. + + :raise VerificationFailed: when the hash method cannot be found + + ''' + + for (hashname, asn1code) in HASH_ASN1.items(): + if not method_hash.startswith(asn1code): + continue + + return (hashname, method_hash[len(asn1code):]) + + raise VerificationError('Verification failed') + + +__all__ = ['encrypt', 'decrypt', 'sign', 'verify', + 'DecryptionError', 'VerificationError', 'CryptoError'] + +if __name__ == '__main__': + print('Running doctests 1000x or until failure') + import doctest + + for count in range(1000): + (failures, tests) = doctest.testmod() + if failures: + break + + if count and count % 100 == 0: + print('%i times' % count) + + print('Doctests done') diff --git a/src/lib/rsa/prime.py b/src/lib/rsa/prime.py new file mode 100644 index 00000000..7422eb1d --- /dev/null +++ b/src/lib/rsa/prime.py @@ -0,0 +1,166 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Numerical functions related to primes. + +Implementation based on the book Algorithm Design by Michael T. Goodrich and +Roberto Tamassia, 2002. +''' + +__all__ = [ 'getprime', 'are_relatively_prime'] + +import rsa.randnum + +def gcd(p, q): + '''Returns the greatest common divisor of p and q + + >>> gcd(48, 180) + 12 + ''' + + while q != 0: + if p < q: (p,q) = (q,p) + (p,q) = (q, p % q) + return p + + +def jacobi(a, b): + '''Calculates the value of the Jacobi symbol (a/b) where both a and b are + positive integers, and b is odd + + :returns: -1, 0 or 1 + ''' + + assert a > 0 + assert b > 0 + + if a == 0: return 0 + result = 1 + while a > 1: + if a & 1: + if ((a-1)*(b-1) >> 2) & 1: + result = -result + a, b = b % a, a + else: + if (((b * b) - 1) >> 3) & 1: + result = -result + a >>= 1 + if a == 0: return 0 + return result + +def jacobi_witness(x, n): + '''Returns False if n is an Euler pseudo-prime with base x, and + True otherwise. + ''' + + j = jacobi(x, n) % n + + f = pow(x, n >> 1, n) + + if j == f: return False + return True + +def randomized_primality_testing(n, k): + '''Calculates whether n is composite (which is always correct) or + prime (which is incorrect with error probability 2**-k) + + Returns False if the number is composite, and True if it's + probably prime. + ''' + + # 50% of Jacobi-witnesses can report compositness of non-prime numbers + + # The implemented algorithm using the Jacobi witness function has error + # probability q <= 0.5, according to Goodrich et. al + # + # q = 0.5 + # t = int(math.ceil(k / log(1 / q, 2))) + # So t = k / log(2, 2) = k / 1 = k + # this means we can use range(k) rather than range(t) + + for _ in range(k): + x = rsa.randnum.randint(n-1) + if jacobi_witness(x, n): return False + + return True + +def is_prime(number): + '''Returns True if the number is prime, and False otherwise. + + >>> is_prime(42) + False + >>> is_prime(41) + True + ''' + + return randomized_primality_testing(number, 6) + +def getprime(nbits): + '''Returns a prime number that can be stored in 'nbits' bits. + + >>> p = getprime(128) + >>> is_prime(p-1) + False + >>> is_prime(p) + True + >>> is_prime(p+1) + False + + >>> from rsa import common + >>> common.bit_size(p) == 128 + True + + ''' + + while True: + integer = rsa.randnum.read_random_int(nbits) + + # Make sure it's odd + integer |= 1 + + # Test for primeness + if is_prime(integer): + return integer + + # Retry if not prime + + +def are_relatively_prime(a, b): + '''Returns True if a and b are relatively prime, and False if they + are not. + + >>> are_relatively_prime(2, 3) + 1 + >>> are_relatively_prime(2, 4) + 0 + ''' + + d = gcd(a, b) + return (d == 1) + +if __name__ == '__main__': + print('Running doctests 1000x or until failure') + import doctest + + for count in range(1000): + (failures, tests) = doctest.testmod() + if failures: + break + + if count and count % 100 == 0: + print('%i times' % count) + + print('Doctests done') diff --git a/src/lib/rsa/randnum.py b/src/lib/rsa/randnum.py new file mode 100644 index 00000000..0e782744 --- /dev/null +++ b/src/lib/rsa/randnum.py @@ -0,0 +1,85 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Functions for generating random numbers.''' + +# Source inspired by code by Yesudeep Mangalapilly + +import os + +from rsa import common, transform +from rsa._compat import byte + +def read_random_bits(nbits): + '''Reads 'nbits' random bits. + + If nbits isn't a whole number of bytes, an extra byte will be appended with + only the lower bits set. + ''' + + nbytes, rbits = divmod(nbits, 8) + + # Get the random bytes + randomdata = os.urandom(nbytes) + + # Add the remaining random bits + if rbits > 0: + randomvalue = ord(os.urandom(1)) + randomvalue >>= (8 - rbits) + randomdata = byte(randomvalue) + randomdata + + return randomdata + + +def read_random_int(nbits): + '''Reads a random integer of approximately nbits bits. + ''' + + randomdata = read_random_bits(nbits) + value = transform.bytes2int(randomdata) + + # Ensure that the number is large enough to just fill out the required + # number of bits. + value |= 1 << (nbits - 1) + + return value + +def randint(maxvalue): + '''Returns a random integer x with 1 <= x <= maxvalue + + May take a very long time in specific situations. If maxvalue needs N bits + to store, the closer maxvalue is to (2 ** N) - 1, the faster this function + is. + ''' + + bit_size = common.bit_size(maxvalue) + + tries = 0 + while True: + value = read_random_int(bit_size) + if value <= maxvalue: + break + + if tries and tries % 10 == 0: + # After a lot of tries to get the right number of bits but still + # smaller than maxvalue, decrease the number of bits by 1. That'll + # dramatically increase the chances to get a large enough number. + bit_size -= 1 + tries += 1 + + return value + + diff --git a/src/lib/rsa/transform.py b/src/lib/rsa/transform.py new file mode 100644 index 00000000..c740b2d2 --- /dev/null +++ b/src/lib/rsa/transform.py @@ -0,0 +1,220 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Data transformation functions. + +From bytes to a number, number to bytes, etc. +''' + +from __future__ import absolute_import + +try: + # We'll use psyco if available on 32-bit architectures to speed up code. + # Using psyco (if available) cuts down the execution time on Python 2.5 + # at least by half. + import psyco + psyco.full() +except ImportError: + pass + +import binascii +from struct import pack +from rsa import common +from rsa._compat import is_integer, b, byte, get_word_alignment, ZERO_BYTE, EMPTY_BYTE + + +def bytes2int(raw_bytes): + r'''Converts a list of bytes or an 8-bit string to an integer. + + When using unicode strings, encode it to some encoding like UTF8 first. + + >>> (((128 * 256) + 64) * 256) + 15 + 8405007 + >>> bytes2int('\x80@\x0f') + 8405007 + + ''' + + return int(binascii.hexlify(raw_bytes), 16) + + +def _int2bytes(number, block_size=None): + r'''Converts a number to a string of bytes. + + Usage:: + + >>> _int2bytes(123456789) + '\x07[\xcd\x15' + >>> bytes2int(_int2bytes(123456789)) + 123456789 + + >>> _int2bytes(123456789, 6) + '\x00\x00\x07[\xcd\x15' + >>> bytes2int(_int2bytes(123456789, 128)) + 123456789 + + >>> _int2bytes(123456789, 3) + Traceback (most recent call last): + ... + OverflowError: Needed 4 bytes for number, but block size is 3 + + @param number: the number to convert + @param block_size: the number of bytes to output. If the number encoded to + bytes is less than this, the block will be zero-padded. When not given, + the returned block is not padded. + + @throws OverflowError when block_size is given and the number takes up more + bytes than fit into the block. + ''' + # Type checking + if not is_integer(number): + raise TypeError("You must pass an integer for 'number', not %s" % + number.__class__) + + if number < 0: + raise ValueError('Negative numbers cannot be used: %i' % number) + + # Do some bounds checking + if number == 0: + needed_bytes = 1 + raw_bytes = [ZERO_BYTE] + else: + needed_bytes = common.byte_size(number) + raw_bytes = [] + + # You cannot compare None > 0 in Python 3x. It will fail with a TypeError. + if block_size and block_size > 0: + if needed_bytes > block_size: + raise OverflowError('Needed %i bytes for number, but block size ' + 'is %i' % (needed_bytes, block_size)) + + # Convert the number to bytes. + while number > 0: + raw_bytes.insert(0, byte(number & 0xFF)) + number >>= 8 + + # Pad with zeroes to fill the block + if block_size and block_size > 0: + padding = (block_size - needed_bytes) * ZERO_BYTE + else: + padding = EMPTY_BYTE + + return padding + EMPTY_BYTE.join(raw_bytes) + + +def bytes_leading(raw_bytes, needle=ZERO_BYTE): + ''' + Finds the number of prefixed byte occurrences in the haystack. + + Useful when you want to deal with padding. + + :param raw_bytes: + Raw bytes. + :param needle: + The byte to count. Default \000. + :returns: + The number of leading needle bytes. + ''' + leading = 0 + # Indexing keeps compatibility between Python 2.x and Python 3.x + _byte = needle[0] + for x in raw_bytes: + if x == _byte: + leading += 1 + else: + break + return leading + + +def int2bytes(number, fill_size=None, chunk_size=None, overflow=False): + ''' + Convert an unsigned integer to bytes (base-256 representation):: + + Does not preserve leading zeros if you don't specify a chunk size or + fill size. + + .. NOTE: + You must not specify both fill_size and chunk_size. Only one + of them is allowed. + + :param number: + Integer value + :param fill_size: + If the optional fill size is given the length of the resulting + byte string is expected to be the fill size and will be padded + with prefix zero bytes to satisfy that length. + :param chunk_size: + If optional chunk size is given and greater than zero, pad the front of + the byte string with binary zeros so that the length is a multiple of + ``chunk_size``. + :param overflow: + ``False`` (default). If this is ``True``, no ``OverflowError`` + will be raised when the fill_size is shorter than the length + of the generated byte sequence. Instead the byte sequence will + be returned as is. + :returns: + Raw bytes (base-256 representation). + :raises: + ``OverflowError`` when fill_size is given and the number takes up more + bytes than fit into the block. This requires the ``overflow`` + argument to this function to be set to ``False`` otherwise, no + error will be raised. + ''' + if number < 0: + raise ValueError("Number must be an unsigned integer: %d" % number) + + if fill_size and chunk_size: + raise ValueError("You can either fill or pad chunks, but not both") + + # Ensure these are integers. + number & 1 + + raw_bytes = b('') + + # Pack the integer one machine word at a time into bytes. + num = number + word_bits, _, max_uint, pack_type = get_word_alignment(num) + pack_format = ">%s" % pack_type + while num > 0: + raw_bytes = pack(pack_format, num & max_uint) + raw_bytes + num >>= word_bits + # Obtain the index of the first non-zero byte. + zero_leading = bytes_leading(raw_bytes) + if number == 0: + raw_bytes = ZERO_BYTE + # De-padding. + raw_bytes = raw_bytes[zero_leading:] + + length = len(raw_bytes) + if fill_size and fill_size > 0: + if not overflow and length > fill_size: + raise OverflowError( + "Need %d bytes for number, but fill size is %d" % + (length, fill_size) + ) + raw_bytes = raw_bytes.rjust(fill_size, ZERO_BYTE) + elif chunk_size and chunk_size > 0: + remainder = length % chunk_size + if remainder: + padding_size = chunk_size - remainder + raw_bytes = raw_bytes.rjust(length + padding_size, ZERO_BYTE) + return raw_bytes + + +if __name__ == '__main__': + import doctest + doctest.testmod() + diff --git a/src/lib/rsa/util.py b/src/lib/rsa/util.py new file mode 100644 index 00000000..5bbb70be --- /dev/null +++ b/src/lib/rsa/util.py @@ -0,0 +1,81 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''Utility functions.''' + +from __future__ import with_statement, print_function + +import sys +from optparse import OptionParser + +import rsa.key + +def private_to_public(): + '''Reads a private key and outputs the corresponding public key.''' + + # Parse the CLI options + parser = OptionParser(usage='usage: %prog [options]', + description='Reads a private key and outputs the ' + 'corresponding public key. Both private and public keys use ' + 'the format described in PKCS#1 v1.5') + + parser.add_option('-i', '--input', dest='infilename', type='string', + help='Input filename. Reads from stdin if not specified') + parser.add_option('-o', '--output', dest='outfilename', type='string', + help='Output filename. Writes to stdout of not specified') + + parser.add_option('--inform', dest='inform', + help='key format of input - default PEM', + choices=('PEM', 'DER'), default='PEM') + + parser.add_option('--outform', dest='outform', + help='key format of output - default PEM', + choices=('PEM', 'DER'), default='PEM') + + (cli, cli_args) = parser.parse_args(sys.argv) + + # Read the input data + if cli.infilename: + print('Reading private key from %s in %s format' % \ + (cli.infilename, cli.inform), file=sys.stderr) + with open(cli.infilename, 'rb') as infile: + in_data = infile.read() + else: + print('Reading private key from stdin in %s format' % cli.inform, + file=sys.stderr) + in_data = sys.stdin.read().encode('ascii') + + assert type(in_data) == bytes, type(in_data) + + + # Take the public fields and create a public key + priv_key = rsa.key.PrivateKey.load_pkcs1(in_data, cli.inform) + pub_key = rsa.key.PublicKey(priv_key.n, priv_key.e) + + # Save to the output file + out_data = pub_key.save_pkcs1(cli.outform) + + if cli.outfilename: + print('Writing public key to %s in %s format' % \ + (cli.outfilename, cli.outform), file=sys.stderr) + with open(cli.outfilename, 'wb') as outfile: + outfile.write(out_data) + else: + print('Writing public key to stdout in %s format' % cli.outform, + file=sys.stderr) + sys.stdout.write(out_data.decode('ascii')) + + diff --git a/src/lib/rsa/varblock.py b/src/lib/rsa/varblock.py new file mode 100644 index 00000000..c7d96ae6 --- /dev/null +++ b/src/lib/rsa/varblock.py @@ -0,0 +1,155 @@ +# -*- coding: utf-8 -*- +# +# Copyright 2011 Sybren A. Stüvel +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +'''VARBLOCK file support + +The VARBLOCK file format is as follows, where || denotes byte concatenation: + + FILE := VERSION || BLOCK || BLOCK ... + + BLOCK := LENGTH || DATA + + LENGTH := varint-encoded length of the subsequent data. Varint comes from + Google Protobuf, and encodes an integer into a variable number of bytes. + Each byte uses the 7 lowest bits to encode the value. The highest bit set + to 1 indicates the next byte is also part of the varint. The last byte will + have this bit set to 0. + +This file format is called the VARBLOCK format, in line with the varint format +used to denote the block sizes. + +''' + +from rsa._compat import byte, b + + +ZERO_BYTE = b('\x00') +VARBLOCK_VERSION = 1 + +def read_varint(infile): + '''Reads a varint from the file. + + When the first byte to be read indicates EOF, (0, 0) is returned. When an + EOF occurs when at least one byte has been read, an EOFError exception is + raised. + + @param infile: the file-like object to read from. It should have a read() + method. + @returns (varint, length), the read varint and the number of read bytes. + ''' + + varint = 0 + read_bytes = 0 + + while True: + char = infile.read(1) + if len(char) == 0: + if read_bytes == 0: + return (0, 0) + raise EOFError('EOF while reading varint, value is %i so far' % + varint) + + byte = ord(char) + varint += (byte & 0x7F) << (7 * read_bytes) + + read_bytes += 1 + + if not byte & 0x80: + return (varint, read_bytes) + + +def write_varint(outfile, value): + '''Writes a varint to a file. + + @param outfile: the file-like object to write to. It should have a write() + method. + @returns the number of written bytes. + ''' + + # there is a big difference between 'write the value 0' (this case) and + # 'there is nothing left to write' (the false-case of the while loop) + + if value == 0: + outfile.write(ZERO_BYTE) + return 1 + + written_bytes = 0 + while value > 0: + to_write = value & 0x7f + value = value >> 7 + + if value > 0: + to_write |= 0x80 + + outfile.write(byte(to_write)) + written_bytes += 1 + + return written_bytes + + +def yield_varblocks(infile): + '''Generator, yields each block in the input file. + + @param infile: file to read, is expected to have the VARBLOCK format as + described in the module's docstring. + @yields the contents of each block. + ''' + + # Check the version number + first_char = infile.read(1) + if len(first_char) == 0: + raise EOFError('Unable to read VARBLOCK version number') + + version = ord(first_char) + if version != VARBLOCK_VERSION: + raise ValueError('VARBLOCK version %i not supported' % version) + + while True: + (block_size, read_bytes) = read_varint(infile) + + # EOF at block boundary, that's fine. + if read_bytes == 0 and block_size == 0: + break + + block = infile.read(block_size) + + read_size = len(block) + if read_size != block_size: + raise EOFError('Block size is %i, but could read only %i bytes' % + (block_size, read_size)) + + yield block + + +def yield_fixedblocks(infile, blocksize): + '''Generator, yields each block of ``blocksize`` bytes in the input file. + + :param infile: file to read and separate in blocks. + :returns: a generator that yields the contents of each block + ''' + + while True: + block = infile.read(blocksize) + + read_bytes = len(block) + if read_bytes == 0: + break + + yield block + + if read_bytes < blocksize: + break + diff --git a/src/main.py b/src/main.py index 98cde932..c45dd0e5 100644 --- a/src/main.py +++ b/src/main.py @@ -38,12 +38,22 @@ if not os.path.isfile("%s/users.json" % config.data_dir): # Setup logging if config.action == "main": + from util import helper + log_file_path = "%s/debug.log" % config.log_dir + try: + helper.openLocked(log_file_path, "a") + except IOError as err: + print "Can't lock %s file, your ZeroNet client is probably already running, exiting... (%s)" % (log_file_path, err) + sys.exit() + if os.path.isfile("%s/debug.log" % config.log_dir): # Simple logrotate if os.path.isfile("%s/debug-last.log" % config.log_dir): os.unlink("%s/debug-last.log" % config.log_dir) os.rename("%s/debug.log" % config.log_dir, "%s/debug-last.log" % config.log_dir) - logging.basicConfig(format='[%(asctime)s] %(levelname)-8s %(name)s %(message)s', - level=logging.DEBUG, filename="%s/debug.log" % config.log_dir) + logging.basicConfig( + format='[%(asctime)s] %(levelname)-8s %(name)s %(message)s', + level=logging.DEBUG, stream=helper.openLocked(log_file_path, "a") + ) else: logging.basicConfig(level=logging.DEBUG, stream=open(os.devnull, "w")) # No file logging if action is not main @@ -84,11 +94,18 @@ if config.proxy: import urllib2 logging.info("Patching sockets to socks proxy: %s" % config.proxy) config.fileserver_ip = '127.0.0.1' # Do not accept connections anywhere but localhost - SocksProxy.monkeyPath(*config.proxy.split(":")) - + SocksProxy.monkeyPatch(*config.proxy.split(":")) +elif config.tor == "always": + from util import SocksProxy + import urllib2 + logging.info("Patching sockets to tor socks proxy: %s" % config.tor_proxy) + config.fileserver_ip = '127.0.0.1' # Do not accept connections anywhere but localhost + SocksProxy.monkeyPatch(*config.tor_proxy.split(":")) + config.disable_udp = True # -- Actions -- + @PluginManager.acceptPlugins class Actions(object): def call(self, function_name, kwargs): @@ -237,6 +254,7 @@ class Actions(object): logging.info("Creating FileServer....") file_server = FileServer() + site.connection_server = file_server file_server_thread = gevent.spawn(file_server.start, check_sites=False) # Dont check every site integrity time.sleep(0) @@ -265,7 +283,6 @@ class Actions(object): logging.info(my_peer.request("sitePublish", {"site": site.address, "inner_path": inner_path})) logging.info("Done.") - # Crypto commands def cryptPrivatekeyToAddress(self, privatekey=None): from Crypt import CryptBitcoin @@ -287,6 +304,8 @@ class Actions(object): global file_server from Connection import ConnectionServer file_server = ConnectionServer("127.0.0.1", 1234) + from Crypt import CryptConnection + CryptConnection.manager.loadCerts() from Peer import Peer logging.info("Pinging 5 times peer: %s:%s..." % (peer_ip, int(peer_port))) @@ -296,29 +315,43 @@ class Actions(object): print peer.ping(), print "Response time: %.3fs (crypt: %s)" % (time.time() - s, peer.connection.crypt) time.sleep(1) + peer.remove() + print "Reconnect test..." + peer = Peer(peer_ip, peer_port) + for i in range(5): + s = time.time() + print peer.ping(), + print "Response time: %.3fs (crypt: %s)" % (time.time() - s, peer.connection.crypt) + time.sleep(1) def peerGetFile(self, peer_ip, peer_port, site, filename, benchmark=False): logging.info("Opening a simple connection server") global file_server from Connection import ConnectionServer - file_server = ConnectionServer() + file_server = ConnectionServer("127.0.0.1", 1234) + from Crypt import CryptConnection + CryptConnection.manager.loadCerts() from Peer import Peer logging.info("Getting %s/%s from peer: %s:%s..." % (site, filename, peer_ip, peer_port)) peer = Peer(peer_ip, peer_port) s = time.time() - peer.getFile(site, filename) if benchmark: for i in range(10): - print peer.getFile(site, filename), + peer.getFile(site, filename), print "Response time: %.3fs" % (time.time() - s) raw_input("Check memory") + else: + print peer.getFile(site, filename).read() def peerCmd(self, peer_ip, peer_port, cmd, parameters): logging.info("Opening a simple connection server") global file_server from Connection import ConnectionServer file_server = ConnectionServer() + from Crypt import CryptConnection + CryptConnection.manager.loadCerts() + from Peer import Peer peer = Peer(peer_ip, peer_port) diff --git a/src/util/SocksProxy.py b/src/util/SocksProxy.py index a11a385d..7a99e2aa 100644 --- a/src/util/SocksProxy.py +++ b/src/util/SocksProxy.py @@ -4,8 +4,12 @@ from lib.PySocks import socks def create_connection(address, timeout=None, source_address=None): - sock = socks.socksocket() - sock.connect(address) + if address == "127.0.0.1": + sock = socket.socket_noproxy(socket.AF_INET, socket.SOCK_STREAM) + sock.connect(address) + else: + sock = socks.socksocket() + sock.connect(address) return sock @@ -14,9 +18,9 @@ def getaddrinfo(*args): return [(socket.AF_INET, socket.SOCK_STREAM, 6, '', (args[0], args[1]))] -def monkeyPath(proxy_ip, proxy_port): - print proxy_ip, proxy_port +def monkeyPatch(proxy_ip, proxy_port): socks.setdefaultproxy(socks.PROXY_TYPE_SOCKS5, proxy_ip, int(proxy_port)) + socket.socket_noproxy = socket.socket socket.socket = socks.socksocket socket.create_connection = create_connection socket.getaddrinfo = getaddrinfo diff --git a/src/util/UpnpPunch.py b/src/util/UpnpPunch.py index eb4b3f16..d7caea1e 100644 --- a/src/util/UpnpPunch.py +++ b/src/util/UpnpPunch.py @@ -36,7 +36,10 @@ def _m_search_ssdp(local_ip): sock.bind((local_ip, 10000)) sock.sendto(ssdp_request, ('239.255.255.250', 1900)) - sock.settimeout(5) + if local_ip == "127.0.0.1": + sock.settimeout(1) + else: + sock.settimeout(5) try: return sock.recv(2048) @@ -233,6 +236,9 @@ def open_port(port=15441, desc="UpnpPunch"): if __name__ == "__main__": from gevent import monkey monkey.patch_socket() + import time + s = time.time() logging.getLogger().setLevel(logging.DEBUG) print open_port(15441, "ZeroNet") + print "Done in", time.time()-s diff --git a/src/util/helper.py b/src/util/helper.py index 9750af53..937d45da 100644 --- a/src/util/helper.py +++ b/src/util/helper.py @@ -4,18 +4,41 @@ import struct import re import collections import time +import logging +import base64 def atomicWrite(dest, content, mode="w"): - with open(dest + "-new", mode) as f: - f.write(content) - f.flush() - os.fsync(f.fileno()) - if os.path.isfile(dest + "-old"): # Previous incomplete write - os.rename(dest + "-old", dest + "-old-%s" % time.time()) - os.rename(dest, dest + "-old") - os.rename(dest + "-new", dest) - os.unlink(dest + "-old") + try: + with open(dest + "-new", mode) as f: + f.write(content) + f.flush() + os.fsync(f.fileno()) + if os.path.isfile(dest + "-old"): # Previous incomplete write + os.rename(dest + "-old", dest + "-old-%s" % time.time()) + os.rename(dest, dest + "-old") + os.rename(dest + "-new", dest) + os.unlink(dest + "-old") + return True + except Exception, err: + from Debug import Debug + logging.error( + "File %s write failed: %s, reverting..." % + (dest, Debug.formatException(err)) + ) + if os.path.isfile(dest + "-old") and not os.path.isfile(dest): + os.rename(dest + "-old", dest) + return False + + +def openLocked(path, mode="w"): + if os.name == "posix": + import fcntl + f = open(path, mode) + fcntl.flock(f, fcntl.LOCK_EX | fcntl.LOCK_NB) + else: + f = open(path, mode) + return f def shellquote(*args): @@ -25,6 +48,16 @@ def shellquote(*args): return tuple(['"%s"' % arg.replace('"', "") for arg in args]) +def packPeers(peers): + packed_peers = {"ip4": [], "onion": []} + for peer in peers: + if peer.ip.endswith(".onion"): + packed_peers["onion"].append(peer.packMyAddress()) + else: + packed_peers["ip4"].append(peer.packMyAddress()) + return packed_peers + + # ip, port to packed 6byte format def packAddress(ip, port): return socket.inet_aton(ip) + struct.pack("H", port) @@ -32,9 +65,21 @@ def packAddress(ip, port): # From 6byte format to ip, port def unpackAddress(packed): + assert len(packed) == 6, "Invalid length ip4 packed address: %s" % len(packed) return socket.inet_ntoa(packed[0:4]), struct.unpack_from("H", packed, 4)[0] +# onion, port to packed 12byte format +def packOnionAddress(onion, port): + onion = onion.replace(".onion", "") + return base64.b32decode(onion.upper()) + struct.pack("H", port) + + +# From 12byte format to ip, port +def unpackOnionAddress(packed): + return base64.b32encode(packed[0:-2]).lower() + ".onion", struct.unpack("H", packed[-2:])[0] + + # Get dir from file # Return: data/site/content.json -> data/site def getDirname(path): @@ -62,3 +107,34 @@ def mergeDicts(dicts): for key, val in d.iteritems(): back[key].update(val) return dict(back) + + +# Request https url using gevent SSL error workaround +def httpRequest(url, as_file=False): + if url.startswith("http://"): + import urllib + response = urllib.urlopen(url) + else: # Hack to avoid Python gevent ssl errors + import socket + import httplib + import ssl + + host, request = re.match("https://(.*?)(/.*?)$", url).groups() + + conn = httplib.HTTPSConnection(host) + sock = socket.create_connection((conn.host, conn.port), conn.timeout, conn.source_address) + conn.sock = ssl.wrap_socket(sock, conn.key_file, conn.cert_file) + conn.request("GET", request) + response = conn.getresponse() + + if as_file: + import cStringIO as StringIO + data = StringIO.StringIO() + while True: + buff = response.read(1024 * 16) + if not buff: + break + data.write(buff) + return data + else: + return response diff --git a/tools/tor/manual_install.txt b/tools/tor/manual_install.txt new file mode 100644 index 00000000..e571d0f7 --- /dev/null +++ b/tools/tor/manual_install.txt @@ -0,0 +1,34 @@ +Minimum version requred: 0.2.7.5 + +Manual install method for Windows: + - The download/unpack process is automatized on Windows, but if its fails for any reasons follow the next steps! + - Download Expert Bundle from https://www.torproject.org/download/download.html (tor-win32-*.zip) + - Copy everything from the archive's `Tor` directory `tools\tor` and the files from `Data\Tor` to `tools\tor\data` (you need to create it) + - You should get directory structure similar to this: + utils\tor: + │ libeay32.dll + │ libevent-2-0-5.dll + │ libevent_core-2-0-5.dll + │ libevent_extra-2-0-5.dll + │ libgcc_s_sjlj-1.dll + │ libssp-0.dll + │ manual_install.txt + │ ssleay32.dll + │ start.cmd + │ tor.exe + │ torrc + │ zlib1.dll + + utils\tor\data: + │ geoip + │ geoip6 + - Start ZeroNet, it will run and use the utils\tor\tor.exe file + +For other OS: + - Follow install instructions at: https://www.torproject.org/docs/installguide.html.en + - Edit torrc configuration file + - Remove the `#` character from line `ControlPort 9051` + - Restart tor service + - Start ZeroNet + +For more info check: http://zeronet.readthedocs.org/en/latest/faq/ \ No newline at end of file diff --git a/tools/tor/start.cmd b/tools/tor/start.cmd new file mode 100644 index 00000000..83642235 --- /dev/null +++ b/tools/tor/start.cmd @@ -0,0 +1 @@ +tor.exe -f torrc \ No newline at end of file diff --git a/tools/tor/torrc b/tools/tor/torrc new file mode 100644 index 00000000..a6df211c --- /dev/null +++ b/tools/tor/torrc @@ -0,0 +1,13 @@ +# Tor config for ZeroNet + +DataDirectory data +DirReqStatistics 0 +GeoIPFile geoip\geoip +GeoIPv6File geoip\geoip6 + +# Log notice file data\notice.log + +ControlPort 49051 +SOCKSPort 49050 + +CookieAuthentication 1 \ No newline at end of file diff --git a/update.py b/update.py index 3830dd37..e691b791 100644 --- a/update.py +++ b/update.py @@ -42,6 +42,8 @@ def update(): print "Extracting...", zip = zipfile.ZipFile(data) for inner_path in zip.namelist(): + if ".." in inner_path: + continue inner_path = inner_path.replace("\\", "/") # Make sure we have unix path print ".", dest_path = inner_path.replace("ZeroNet-master/", "") diff --git a/zeronet.py b/zeronet.py index 959699c8..fbfda76f 100644 --- a/zeronet.py +++ b/zeronet.py @@ -10,6 +10,7 @@ def main(): main = None try: + sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src/lib")) # External liblary directory sys.path.insert(0, os.path.join(os.path.dirname(__file__), "src")) # Imports relative to src import main main.start()