xch-blockchain/chia/seeder/dns_server.py

293 lines
9.9 KiB
Python

from __future__ import annotations
import asyncio
import ipaddress
import logging
import random
import signal
import traceback
from pathlib import Path
from typing import Any, Dict, List
import aiosqlite
from dnslib import AAAA, CNAME, MX, NS, QTYPE, RR, SOA, A, DNSHeader, DNSRecord
from chia.util.chia_logging import initialize_logging
from chia.util.config import load_config
from chia.util.default_root import DEFAULT_ROOT_PATH
from chia.util.path import path_from_root
# Service name passed to load_config() and initialize_logging() in main().
SERVICE_NAME = "seeder"
# Module logger shared by every class and function in this file.
log = logging.getLogger(__name__)
# DNS snippet taken from: https://gist.github.com/pklaus/b5a7876d4d2cf7271873
class DomainName(str):
    """A ``str`` whose unknown attributes build sub-domain names.

    ``DomainName("example.com.").ns1`` evaluates to ``DomainName("ns1.example.com.")``.
    ``__getattr__`` is only consulted after normal lookup fails, so regular ``str``
    methods (``lower``, ``endswith``, ...) keep working as usual.
    """

    def __getattr__(self, item):
        subdomain = ".".join((item, self))
        return DomainName(subdomain)
# Module-level zone settings. main() overwrites D, ns, TTL, soa_record and ns_records
# from config.yaml before the server starts; DNSServer.dns_response reads them on
# every query.
D = None  # base domain served (a DomainName), set from config["domain_name"]
ns = None  # primary name server (a DomainName), set from config["nameserver"]
# Address returned for the ns1/ns2/mail host records. It is never reassigned in this
# file. NOTE(review): a loopback address here looks like a placeholder; confirm.
IP = "127.0.0.1"
TTL = None  # TTL used for every answer/authority/additional record, config["ttl"]
soa_record = None  # SOA rdata, built in main() from config["soa"]
ns_records: List[Any] = []  # NS rdata for the zone, built in main()
class EchoServerProtocol(asyncio.DatagramProtocol):
    """UDP protocol that answers each datagram with the result of an async callback.

    For every received datagram ``callback(data)`` is awaited. A non-``None`` result
    is sent back to the sender; ``None`` means "send nothing". Replies are funnelled
    through ``data_queue`` and written by a single long-lived ``respond`` task.
    """

    def __init__(self, callback):
        self.data_queue = asyncio.Queue()
        self.callback = callback
        # The event loop only keeps weak references to tasks, so a task nobody
        # references can be garbage-collected before it finishes. Hold strong
        # references here until each task completes.
        self._background_tasks = set()
        self._spawn(self.respond())

    def _spawn(self, coro):
        """Schedule *coro* as a task and keep it referenced until it is done."""
        task = asyncio.ensure_future(coro)
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task

    def connection_made(self, transport):
        self.transport = transport

    def datagram_received(self, data, addr):
        self._spawn(self.handler(data, addr))

    async def respond(self):
        """Forever pop ``(reply, address)`` pairs from the queue and send them."""
        while True:
            try:
                resp, caller = await self.data_queue.get()
                self.transport.sendto(resp, caller)
            except Exception as e:
                log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")

    async def handler(self, data, caller):
        """Compute the reply for one datagram and enqueue it (if any) for ``respond``."""
        try:
            reply = await self.callback(data)
            if reply is None:
                return
            await self.data_queue.put((reply, caller))
        except Exception as e:
            log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
class DNSServer:
    """Seeder DNS server that answers queries with addresses of reliable peers.

    Peers are loaded periodically from the crawler database (first column of the
    ``good_peers`` table), split by IP family, and handed out in round-robin windows
    so that successive queries see different peers. Zone settings (``D``, ``TTL``,
    ``soa_record``, ``ns_records``, ``IP``) are module-level globals set by ``main()``.
    """

    reliable_peers_v4: List[str]
    reliable_peers_v6: List[str]
    lock: asyncio.Lock  # guards both peer lists and both round-robin pointers
    pointer_v4: int  # next window start index into reliable_peers_v4
    pointer_v6: int  # next window start index into reliable_peers_v6

    def __init__(self, config: Dict, root_path: Path):
        self.reliable_peers_v4 = []
        self.reliable_peers_v6 = []
        self.lock = asyncio.Lock()
        self.pointer_v4 = 0
        self.pointer_v6 = 0
        crawler_db_path: str = config.get("crawler_db_path", "crawler.db")
        self.db_path = path_from_root(root_path, crawler_db_path)
        self.db_path.parent.mkdir(parents=True, exist_ok=True)

    @staticmethod
    def _ip_version(peer: str) -> int:
        """Return 4 or 6 if *peer* is a valid IPv4 / IPv6 literal, otherwise 0."""
        try:
            return ipaddress.ip_address(peer).version
        except ValueError:
            return 0

    @staticmethod
    def _round_robin_window(pool: List[str], pointer: int, count: int):
        """Pick up to *count* peers from *pool* starting at *pointer*.

        Returns ``(selected, new_pointer)``. If *count* <= 0 nothing is selected. If
        the pool is no larger than *count*, a copy of the whole pool is returned and
        the pointer is left unchanged. Otherwise *count* consecutive entries
        (wrapping around) are returned and the pointer advances past them.
        """
        if count <= 0:
            return [], pointer
        size = len(pool)
        if size <= count:
            return list(pool), pointer
        selected = [pool[i % size] for i in range(pointer, pointer + count)]
        return selected, (pointer + count) % size

    async def start(self):
        """Bind the UDP listener on 0.0.0.0:53 and start the peer-refresh task."""
        loop = asyncio.get_running_loop()
        # One protocol instance is created to serve all client requests.
        self.transport, self.protocol = await loop.create_datagram_endpoint(
            lambda: EchoServerProtocol(self.dns_response), local_addr=("0.0.0.0", 53)
        )
        self.reliable_task = asyncio.create_task(self.periodically_get_reliable_peers())

    async def periodically_get_reliable_peers(self):
        """Reload the reliable-peer lists from the crawler database forever.

        The wait before each reload grows by one minute per iteration, capped at
        15 minutes. Any error is logged and the loop continues.
        """
        sleep_interval = 0
        while True:
            sleep_interval = min(15, sleep_interval + 1)
            await asyncio.sleep(sleep_interval * 60)
            try:
                # TODO: double check this. It shouldn't take this long to connect.
                # ``async with`` closes the cursor and connection even when the
                # query fails (e.g. the good_peers table does not exist yet).
                async with aiosqlite.connect(self.db_path, timeout=600) as crawl_db:
                    async with crawl_db.execute("SELECT * from good_peers") as cursor:
                        rows = await cursor.fetchall()
                new_reliable_peers = [row[0] for row in rows]
                if len(new_reliable_peers) > 0:
                    random.shuffle(new_reliable_peers)
                # Classify outside the lock so queries are blocked only for the swap.
                new_v4: List[str] = []
                new_v6: List[str] = []
                for peer in new_reliable_peers:
                    version = self._ip_version(peer)
                    if version == 4:
                        new_v4.append(peer)
                    elif version == 6:
                        new_v6.append(peer)
                    # Entries that are neither IPv4 nor IPv6 are dropped.
                # NOTE: an empty good_peers table clears both lists.
                async with self.lock:
                    self.reliable_peers_v4 = new_v4
                    self.reliable_peers_v6 = new_v6
                    self.pointer_v4 = 0
                    self.pointer_v6 = 0
                log.error(
                    f"Number of reliable peers discovered in dns server:"
                    f" IPv4 count - {len(self.reliable_peers_v4)}"
                    f" IPv6 count - {len(self.reliable_peers_v6)}"
                )
            except Exception as e:
                log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")

    async def get_peers_to_respond(self, ipv4_count, ipv6_count):
        """Return up to *ipv4_count* IPv4 peers followed by up to *ipv6_count* IPv6 peers.

        Each family is served in a round-robin window (see ``_round_robin_window``).
        The result is always a new list, never an alias of the internal peer lists.
        """
        async with self.lock:
            v4, self.pointer_v4 = self._round_robin_window(self.reliable_peers_v4, self.pointer_v4, ipv4_count)
            v6, self.pointer_v6 = self._round_robin_window(self.reliable_peers_v6, self.pointer_v6, ipv6_count)
        return v4 + v6

    async def dns_response(self, data):
        """Build the packed DNS reply for the raw query bytes *data*.

        Returns ``None``, so that no reply is sent, when there are no peers to hand
        out or when handling the query raises (the error is logged).
        """
        try:
            request = DNSRecord.parse(data)
            IPs = [MX(D.mail), soa_record] + ns_records
            qtype = request.q.qtype
            # A (1) -> 32 IPv4 peers, AAAA (28) -> 32 IPv6 peers,
            # ANY (255) -> 16 of each, any other type -> 32 IPv4 peers.
            ipv4_count = 0
            ipv6_count = 0
            if qtype == 1:
                ipv4_count = 32
            elif qtype == 28:
                ipv6_count = 32
            elif qtype == 255:
                ipv4_count = 16
                ipv6_count = 16
            else:
                ipv4_count = 32
            peers = await self.get_peers_to_respond(ipv4_count, ipv6_count)
            if len(peers) == 0:
                return None
            for peer in peers:
                version = self._ip_version(peer)
                if version == 4:
                    IPs.append(A(peer))
                elif version == 6:
                    IPs.append(AAAA(peer))
            # aa is a single header bit. Passing len(IPs) kept only its lowest bit, so
            # the authoritative flag used to flip with the parity of the record count.
            reply = DNSRecord(DNSHeader(id=request.header.id, qr=1, aa=1, ra=1), q=request.q)
            records = {
                D: IPs,
                D.ns1: [A(IP)],  # MX and NS records must never point to a CNAME alias (RFC 2181 section 10.3)
                D.ns2: [A(IP)],
                D.mail: [A(IP)],
                D.andrei: [CNAME(D)],
            }
            qname = request.q.qname
            qn = str(qname)
            qt = QTYPE[qtype]
            # DNS names are case-insensitive (RFC 4343); resolvers that use 0x20 case
            # randomisation send mixed-case names, so compare case-folded forms.
            qn_key = qn.lower()
            zone = D.lower()
            if qn_key == zone or qn_key.endswith("." + zone):
                for name, rrs in records.items():
                    if name.lower() != qn_key:
                        continue
                    for rdata in rrs:
                        rqt = rdata.__class__.__name__
                        if qt in ["*", rqt] or (qt == "ANY" and (rqt == "A" or rqt == "AAAA")):
                            reply.add_answer(
                                RR(rname=qname, rtype=getattr(QTYPE, rqt), rclass=1, ttl=TTL, rdata=rdata)
                            )
                for rdata in ns_records:
                    reply.add_ar(RR(rname=D, rtype=QTYPE.NS, rclass=1, ttl=TTL, rdata=rdata))
                reply.add_auth(RR(rname=D, rtype=QTYPE.SOA, rclass=1, ttl=TTL, rdata=soa_record))
            return reply.pack()
        except Exception as e:
            log.error(f"Exception: {e}. Traceback: {traceback.format_exc()}.")
async def serve_dns(config: Dict, root_path: Path):
    """Start a DNSServer and keep it running until this coroutine is cancelled.

    All request handling happens in the datagram protocol and the peer-refresh task
    created by ``DNSServer.start()``. On cancellation the UDP transport is closed
    and the refresh task is cancelled.
    """
    dns_server = DNSServer(config, root_path)
    await dns_server.start()
    try:
        # Block forever; cancellation (e.g. on shutdown) is the only way out.
        await asyncio.Event().wait()
    finally:
        dns_server.reliable_task.cancel()
        dns_server.transport.close()
async def kill_processes():
    """Placeholder for shutdown work scheduled by ``signal_received``.

    Currently a no-op that returns ``None``.
    """
    # TODO: implement.
# Strong references to tasks spawned from signal handlers. The event loop keeps
# only weak references, so an unreferenced task could be collected before it runs.
_signal_tasks: "set[asyncio.Task[None]]" = set()


def signal_received():
    """Schedule ``kill_processes()`` on the running loop; used as a loop signal handler."""
    task = asyncio.create_task(kill_processes())
    _signal_tasks.add(task)
    task.add_done_callback(_signal_tasks.discard)
async def async_main(config, root_path):
    """Install SIGINT/SIGTERM handlers, then run the DNS server until shutdown.

    Handlers registered with ``loop.add_signal_handler`` replace the default
    behaviour (KeyboardInterrupt / process termination). Besides calling
    ``signal_received`` they therefore cancel this task. Otherwise the process
    could only be stopped with SIGKILL.
    """
    loop = asyncio.get_running_loop()
    main_task = asyncio.current_task()

    def _on_signal():
        signal_received()
        if main_task is not None:
            main_task.cancel()

    try:
        loop.add_signal_handler(signal.SIGINT, _on_signal)
        loop.add_signal_handler(signal.SIGTERM, _on_signal)
    except NotImplementedError:
        log.info("signal handlers unsupported")
    try:
        await serve_dns(config, root_path)
    except asyncio.CancelledError:
        # Expected on SIGINT/SIGTERM: return normally so asyncio.run() exits cleanly.
        log.info("DNS server shutting down")
def main():
    """Load the seeder config, initialise logging and zone globals, then serve DNS."""
    global D, ns, TTL, soa_record, ns_records

    root_path = DEFAULT_ROOT_PATH
    config = load_config(root_path, "config.yaml", SERVICE_NAME)
    initialize_logging(SERVICE_NAME, config["logging"], root_path)

    D = DomainName(config["domain_name"])
    ns = DomainName(config["nameserver"])
    TTL = config["ttl"]

    soa_settings = config["soa"]
    admin_email = soa_settings["rname"]  # email of the domain administrator
    # SOA timer fields in wire order: serial, refresh, retry, expire, minimum.
    soa_times = tuple(
        soa_settings[field] for field in ("serial_number", "refresh", "retry", "expire", "minimum")
    )
    soa_record = SOA(mname=ns, rname=admin_email, times=soa_times)  # mname: primary name server
    ns_records = [NS(ns)]

    asyncio.run(async_main(config=config, root_path=root_path))


if __name__ == "__main__":
    main()