-- Source: mirror of https://github.com/AR2000AR/openComputers_codes.git (synced 2025-09-08 14:41:14 +02:00)
-- File: openComputers_codes/dns_server/etc/rc.d/dnsd.lua (367 lines, 15 KiB, Lua)

local dns = require("dns")
local serialization = require("serialization")
local filesystem = require("filesystem")
local io = require("io")
local ipv4Address = require("network.ipv4").address
local socket = require("socket")
local thread = require("thread")
--=============================================================================
---@alias Resource table<number|string>
---@alias Zone table<string,table<dnsRecordClass,table<dnsRecordType,table<Resource>>>> table<recordName,table<dnsRecordClass,table<dnsRecordType,table<Resource>>>>
---@alias Zones table<string,Zone> table<zoneName,Zone>
--=============================================================================
---Directory scanned for `*.zone` files.
local CONFIG_DIR = "/etc/dnsd/"
---Loaded zones, keyed by zone (origin) name.
---@type Zones
local zones = {}
---Daemon options; `recursion` is copied into the RA flag of every reply.
local config = {recursion = false}
---Socket the daemon listens on (nil while stopped).
---@type UDPSocket
local udpSocket
---Thread running the receive loop (nil while stopped).
local listenerThread

---Print a string built with `string.format`.
---@param fmt string format string
---@param ... any values for the format directives
local function printf(fmt, ...)
    local text = string.format(fmt, ...)
    print(text)
end
---Load a zone file written in a subset of the RFC 1035 master-file format.
---
---Understood syntax:
--- * `$ORIGIN <name>` sets the suffix appended to relative owner names;
--- * `$TTL <seconds>` sets the default TTL;
--- * `;` starts a comment that runs to the end of the line;
--- * `( ... )` lets one record span several lines;
--- * `@` as owner means the origin; an owner field left blank (the line starts
---   with whitespace) reuses the owner of the previous record.
---A record's TTL and class are optional (in either order); when given they also
---become the defaults for the following records.
---@param path string path to a file whose name ends in `.zone`
---@return Zone zone records stored as zone[owner][class][type] = {{ttl, rdata}, ...}
---@return string zoneName the last `$ORIGIN` value of the file ("" when there is none)
local function loadZoneFile(path)
    ---Raise a parse error that names the offending file.
    ---@param message string
    local function fail(message)
        error(string.format("%s: %s", path, message), 0)
    end

    ---Read the next logical line: comment and trailing blanks removed, blank lines
    ---skipped, parenthesised continuations joined with single spaces.
    ---Leading blanks are kept because they mean "same owner as before".
    ---@param fileHandle file*
    ---@return string|nil line nil at end of file
    local function readLine(fileHandle)
        while true do
            local raw = fileHandle:read("l")
            if not raw then return nil end
            local line = raw:gsub(";.*$", ""):gsub("%s+$", "")
            if line ~= "" then
                while line:match("%(") and not line:match("%)") do
                    local continuation = readLine(fileHandle)
                    if not continuation then fail("unterminated '('") end
                    line = line .. " " .. continuation:match("^%s*(.*)$")
                end
                return line
            end
        end
    end

    local parseRdata = {}
    parseRdata[dns.DNSRessourceRecord.TYPE.A] = function(rdata)
        return ipv4Address.fromString(rdata)
    end
    parseRdata[dns.DNSRessourceRecord.TYPE.SOA] = function(rdata)
        rdata = rdata:gsub("%(", " "):gsub("%)", " ")
        local infos = {}
        for info in rdata:gmatch("(%S+)") do
            table.insert(infos, info)
        end
        return infos
    end

    ---Turn an owner field into a fully-qualified name using `origin`.
    ---@param owner string
    ---@param origin string
    ---@return string
    local function qualify(owner, origin)
        if owner == "@" then return origin end
        -- collapse the ".." produced when the owner already ends with "."
        return ((owner .. "." .. origin):gsub("%.%.", "."))
    end

    ---Parse every record of an open zone file.
    ---@param file file*
    ---@return Zone, string
    local function parse(file)
        ---@type Zone
        local zone = {}
        local zoneName = ""
        ---@type string|nil fully-qualified owner of the previous record
        local lastOwner = nil
        ---@type number
        local defaultTTL = 0
        ---@type dnsRecordClass
        local defaultClass = nil
        local line = readLine(file)
        while line do
            if line:match("^%$") then
                local directive = line:sub(2)
                if directive:match("^ORIGIN") then
                    zoneName = directive:match("^%S+%s+(%S+)")
                    if not zoneName then fail("$ORIGIN without a domain name") end
                elseif directive:match("^TTL") then
                    local ttlText = directive:match("^%S+%s+(%d+)")
                    defaultTTL = ttlText and tonumber(ttlText) or 0
                end
            else
                local owner, rest
                if line:match("^%s") then
                    -- blank owner field: same (already qualified) owner as the previous record
                    owner, rest = lastOwner or zoneName, line:match("^%s+(.*)$")
                else
                    local ownerField
                    ownerField, rest = line:match("^(%S+)%s+(.*)$")
                    if not ownerField then fail(string.format("malformed record %q", line)) end
                    owner = qualify(ownerField, zoneName)
                end
                lastOwner = owner

                local ttl, class = defaultTTL, defaultClass
                -- TTL and class are both optional and may appear in either order
                for _ = 1, 2 do
                    local token = rest:match("^%S+")
                    if token and tonumber(token) then
                        ttl = tonumber(token)
                        defaultTTL = ttl
                        rest = rest:match("^%S+%s+(.*)$") or ""
                    elseif token and dns.DNSHeaderFlags.CLASS[token] then
                        class = dns.DNSHeaderFlags.CLASS[token]
                        defaultClass = class
                        rest = rest:match("^%S+%s+(.*)$") or ""
                    end
                end

                local typeName, rdata = rest:match("^(%S+)%s+(.*)$")
                local rtype = typeName and dns.DNSRessourceRecord.TYPE[typeName]
                if not rtype then
                    fail(string.format("unknown or missing record type in %q", line))
                end
                if class == nil then
                    fail(string.format("no class given for %q and no previous class to reuse", line))
                end
                if parseRdata[rtype] then rdata = parseRdata[rtype](rdata) end

                zone[owner] = zone[owner] or {}
                zone[owner][class] = zone[owner][class] or {}
                zone[owner][class][rtype] = zone[owner][class][rtype] or {}
                table.insert(zone[owner][class][rtype], {ttl, rdata})
            end
            line = readLine(file)
        end
        return zone, zoneName
    end

    assert(filesystem.exists(path) and not filesystem.isDirectory(path), "zone file not found")
    assert(path:match("%.zone$"), "Not a zone file")
    local file = io.open(path)
    assert(file, "zone file could not be opened")
    -- parse under pcall so the handle is closed even when the file is malformed
    local ok, zone, zoneName = pcall(parse, file)
    file:close()
    if not ok then error(zone, 0) end
    return zone, zoneName
end
---Load every `*.zone` file found directly inside a directory.
---@param rootDir string directory to scan (a trailing "/" is optional)
---@return Zones zones keyed by each file's `$ORIGIN`; when two files share an origin,
---the one listed later by filesystem.list replaces the earlier one
local function loadAllZoneFiles(rootDir)
    checkArg(1, rootDir, 'string')
    if not filesystem.isDirectory(rootDir) then
        error(string.format("Path %q is not a directory", rootDir), 2)
    end
    local loadedZones = {}
    for entry in filesystem.list(rootDir) do
        if entry:match("%.zone$") then
            -- filesystem.concat inserts the separator, so rootDir need not end with "/"
            local zone, zoneName = loadZoneFile(filesystem.concat(rootDir, entry))
            loadedZones[zoneName] = zone
        end
    end
    return loadedZones
end
---Maximum number of CNAME records followed for one question (guards against CNAME loops).
local MAX_CNAME_CHAIN = 8

---Tell whether `name` is `zoneName` itself or lies below it, comparing whole labels
---(so "badexample.com." is not inside "example.com.").
---@param name string
---@param zoneName string
---@return boolean
local function isInZone(name, zoneName)
    if name == zoneName then return true end
    if zoneName == "" then return false end
    if zoneName == "." then return name:sub(-1) == "." end
    return name:sub(-(#zoneName + 1)) == "." .. zoneName
end

---Find the loaded zone that is the closest enclosing zone of `name`.
---@param name string
---@return Zone|nil zone, string zoneName ("" when no zone matches)
local function findZone(name)
    local bestZone, bestName = nil, ""
    for zoneName, zone in pairs(zones) do
        if #zoneName > #bestName and isInZone(name, zoneName) then
            bestZone, bestName = zone, zoneName
        end
    end
    return bestZone, bestName
end

---Append one DNSRessourceRecord per stored {ttl, rdata} pair to `section`.
---Arguments are passed as (name, type, class, ttl, rdata), the order used by this file's
---answer/CNAME call sites. NOTE(review): confirm against the dns library.
---@param section table list the records are appended to
---@param records table|nil stored {ttl, rdata} pairs; nil is treated as empty
local function appendRecords(section, name, rtype, class, records)
    for _, rr in ipairs(records or {}) do
        table.insert(section, dns.DNSRessourceRecord(name, rtype, class, rr[1], rr[2]))
    end
end

---Answer a DNS question from the loaded zones (authoritative part of RFC 1034 4.3.2).
---@param question DNSQuestion
---@param originalName? string QNAME the client asked for; a missing name is only reported
---as NXDOMAIN when it is this original name (not a CNAME target)
---@param depth? integer number of CNAMEs already followed (internal)
---@return table<DNSRessourceRecord> answer, table<DNSRessourceRecord> authoritative,table<DNSRessourceRecord> additional, boolean NXDOMAIN
local function answerQuestion(question, originalName, depth)
    local TYPE = dns.DNSRessourceRecord.TYPE
    local qname, qtype, qclass = question:getName(), question:getType(), question:getClass()
    originalName = originalName or qname
    depth = depth or 0
    local answers, authoritys, additionals, nxdomain = {}, {}, {}, false

    --Search the available zones for the zone which is the nearest ancestor to QNAME.
    local zone, zName = findZone(qname)
    if not zone then
        -- not our data; only an error for the name the client actually asked about
        return answers, authoritys, additionals, qname == originalName
    end

    local owner = zone[qname]
    if owner then
        -- a name without data in the requested class yields an empty (NODATA) answer
        local node = owner[qclass] or {}
        if node[TYPE.NS] and qname ~= zName then
            -- NS below the apex marks a zone cut: refer the client to the delegated servers,
            -- with their addresses (glue) in the additional section when we have them.
            appendRecords(authoritys, qname, TYPE.NS, qclass, node[TYPE.NS])
            for _, ns in ipairs(node[TYPE.NS]) do
                local target = zone[ns[2]]
                local glue = target and target[qclass] and target[qclass][TYPE.A]
                appendRecords(additionals, ns[2], TYPE.A, qclass, glue)
            end
        elseif node[TYPE.CNAME] and qtype ~= TYPE.CNAME then
            --Copy the CNAME RR into the answer section, then restart with the canonical name.
            local cname = node[TYPE.CNAME][1]
            table.insert(answers, dns.DNSRessourceRecord(qname, TYPE.CNAME, qclass, cname[1], cname[2]))
            if depth < MAX_CNAME_CHAIN then
                local answer2, authority2, additional2, nxdomain2 =
                    answerQuestion(dns.DNSQuestion(cname[2], qtype, qclass), originalName, depth + 1)
                for _, rr in ipairs(answer2) do table.insert(answers, rr) end
                for _, rr in ipairs(authority2) do table.insert(authoritys, rr) end
                for _, rr in ipairs(additional2) do table.insert(additionals, rr) end
                nxdomain = nxdomain or nxdomain2
            end
        else
            appendRecords(answers, qname, qtype, qclass, node[qtype])
        end
    else
        --No exact match: try the "*" label in place of the first label.
        local wildcard = zone[(qname:gsub("^[^%.]+", "*"))]
        if wildcard then
            local node = wildcard[qclass]
            appendRecords(answers, qname, qtype, qclass, node and node[qtype])
        elseif qname == originalName then
            nxdomain = true
        end
    end
    return answers, authoritys, additionals, nxdomain
end
---Handle one DNS request datagram and send the reply to its sender.
---Only answers from the loaded zones are produced; recursion is not implemented.
---@param packet string raw datagram payload
---@param fromIP string
---@param fromPort number
local function handleMessage(packet, fromIP, fromPort)
    local success, dnsMessage = pcall(dns.message.parsePacket, packet)
    if not success or not dnsMessage then
        -- Without a parsed header there is no query ID to answer, so no FORMERR reply can be
        -- built here; raise instead so the caller's error handler records the bad datagram.
        error(string.format("malformed DNS packet from %s:%s (%s)",
            tostring(fromIP), tostring(fromPort), tostring(dnsMessage)), 0)
    end
    local recursionDesired = dnsMessage.header:getFlags():getRD()
    ---@type DNSMessage
    local dnsResponse = {header = dns.DNSHeader(dnsMessage.header:getId(), dnsMessage.header:getFlags(), 0, 0, 0, 0), questions = {}, answer = {}, authority = {}, additional = {}}
    local flags = dnsResponse.header:getFlags()
    flags:setQR(dns.DNSHeaderFlags.QR.REPLY)
    --Set or clear the value of recursion available in the response depending on whether the name server is willing to provide recursive service.
    flags:setRA(config.recursion)
    if config.recursion and recursionDesired then
        --TODO recurse
        flags:setRCODE(dns.RCODE.NIMPL)
    else
        -- RFC 1034 4.3.2 step 1: when recursion is unavailable or not requested, answer from
        -- local data (a set RD bit alone must not make the query fail).
        for i, question in ipairs(dnsMessage.questions) do
            local answer, authoritative, additional, nxdomain = answerQuestion(question)
            if nxdomain then flags:setRCODE(dns.DNSHeaderFlags.RCODE.NXDOMAIN) end
            for _, an in ipairs(answer) do table.insert(dnsResponse.answer, an) end
            for _, au in ipairs(authoritative) do table.insert(dnsResponse.authority, au) end
            for _, ad in ipairs(additional) do table.insert(dnsResponse.additional, ad) end
            -- a referral (authority records present) is not an authoritative answer
            if i == 1 and #authoritative == 0 then flags:setAA(true) end
        end
    end
    dnsResponse.header:setAncount(#dnsResponse.answer)
    dnsResponse.header:setArcount(#dnsResponse.additional)
    dnsResponse.header:setNscount(#dnsResponse.authority)
    udpSocket:sendto(dns.message.packDNSMessage(dnsResponse), fromIP, fromPort)
end
---Tell whether the daemon currently holds a UDP socket handle.
---@return boolean
local function getStatus()
    return udpSocket and true or false
end
---Receive datagrams on `listenedSocket` forever and pass each to handleMessage.
---An error raised while handling one datagram is reported through event.onError
---and the loop keeps running.
---@param listenedSocket UDPSocket
local function listenSocket(listenedSocket)
    checkArg(1, listenedSocket, 'table')
    repeat
        local received = table.pack(listenedSocket:receivefrom())
        -- only three-value results (presumably payload, ip, port) are handled
        if received.n == 3 then
            local ok, err = pcall(handleMessage, received[1], received[2], received[3])
            if not ok then
                require("event").onError(err)
            end
        end
        os.sleep()
    until false
end
--RC METHODS===================================================================
---rc entry point: load the zone files, bind the UDP socket and start the listener thread.
---Prints a notice and does nothing when the daemon already holds a socket. On any failure
---the new socket is closed and the daemon state is left unchanged.
---@diagnostic disable-next-line: lowercase-global
function start()
    if getStatus() then
        print("Already running")
        return
    end
    -- load zones first so a broken zone file fails before any socket is opened
    local loadedZones = loadAllZoneFiles(CONFIG_DIR)
    local newSocket = assert(socket.udp())
    local bound, reason = newSocket:setsockname("*", 51)
    if not bound then
        newSocket:close()
        error(reason or "could not bind UDP port 51", 0)
    end
    zones = loadedZones
    udpSocket = newSocket
    listenerThread = thread.create(listenSocket, udpSocket):detach()
end
---rc entry point: kill the listener thread, close the socket, forget the loaded
---zones and unload this rc script.
---@diagnostic disable-next-line: lowercase-global
function stop()
    local worker = listenerThread
    if worker then
        worker:kill()
        listenerThread = nil
    end
    local openSocket = udpSocket
    if openSocket then
        openSocket:close()
        ---@diagnostic disable-next-line: cast-local-type
        udpSocket = nil
    end
    ---@diagnostic disable-next-line: cast-local-type
    zones = nil
    local rc = require("rc")
    rc.unload("dnsd")
end
---rc entry point: print "Running" or "Not running" depending on getStatus().
---@diagnostic disable-next-line: lowercase-global
function status()
    local state = getStatus() and "Running" or "Not running"
    print(state)
end
---rc entry point: try to parse every zone file and print "ok" or the error raised.
---When the daemon is not running the rc script is unloaded first (presumably so rc does
---not keep it loaded for a one-off command).
---@diagnostic disable-next-line: lowercase-global
function testZone()
    if not getStatus() then
        require("rc").unload("dnsd")
    end
    local ok, result = pcall(loadAllZoneFiles, CONFIG_DIR)
    if ok then
        print("ok")
    else
        print("Error : ", result)
    end
end
---rc entry point: print the records of the zone files in CONFIG_DIR, one per line as
---owner, class, type, TTL and data.
---@param name? string only print the zone with this origin ("example.com" also matches
---"example.com."); every zone is printed when nil
---@diagnostic disable-next-line: lowercase-global
function printZone(name)
    if not getStatus() then require("rc").unload("dnsd") end
    local function reverseDict(dict)
        local r = {}
        for k, v in pairs(dict) do
            r[v] = k
        end
        return r
    end
    -- value -> mnemonic tables, built once instead of once per record
    local classNames = reverseDict(dns.DNSHeaderFlags.CLASS)
    local typeNames = reverseDict(dns.DNSRessourceRecord.TYPE)
    local tmpzones = loadAllZoneFiles(CONFIG_DIR)
    for zname, zone in pairs(tmpzones) do
        if name == nil or zname == name or zname == name .. "." then
            print(zname)
            for owner, classes in pairs(zone) do
                for class, types in pairs(classes) do
                    for rtype, records in pairs(types) do
                        for _, record in pairs(records) do
                            local recordData = record[2]
                            if rtype == dns.DNSRessourceRecord.TYPE.A then
                                recordData = ipv4Address.tostring(recordData)
                            elseif type(recordData) == "table" then
                                recordData = serialization.serialize(recordData)
                            end
                            printf("\t%s\t%s\t%s\t%d\t%q", owner, classNames[class], typeNames[rtype], record[1], recordData)
                        end
                    end
                end
            end
        end
    end
end
---rc entry point: reload the zone files and print the serialized result of answerQuestion
---for one question. The reloaded zones replace the ones a running daemon answers from.
---@param name string domain name to look up
---@param rtype? string record type mnemonic (default "A")
---@param class? string class mnemonic (default "IN")
---@diagnostic disable-next-line: lowercase-global
function resolve(name, rtype, class)
    checkArg(1, name, "string")
    if not getStatus() then require("rc").unload("dnsd") end
    rtype = rtype or "A"
    class = class or "IN"
    local typeValue = dns.DNSRessourceRecord.TYPE[rtype]
    if typeValue == nil then error(string.format("Unknown record type %q", rtype), 2) end
    local classValue = dns.DNSHeaderFlags.CLASS[class]
    if classValue == nil then error(string.format("Unknown record class %q", class), 2) end
    zones = loadAllZoneFiles(CONFIG_DIR)
    local question = dns.DNSQuestion(name, typeValue, classValue)
    print(serialization.serialize(table.pack(answerQuestion(question))))
end