diff --git a/icable_server/icableServer.py b/icable_server/icableServer.py index 37406a7..ef1f88d 100644 --- a/icable_server/icableServer.py +++ b/icable_server/icableServer.py @@ -13,6 +13,7 @@ from pathlib import Path from signal import SIGINT, SIGTERM, signal from string import Template from urllib.parse import parse_qs +import pathlib import models.database import models.firewall @@ -23,7 +24,10 @@ from icable.firewall import * from icable.packet import * from icable.protocol import * -db = models.database.DatabaseHandler('icable.db') +dataPath = pathlib.Path(pathlib.Path.home(),'.local','share','icable') +dataPath.mkdir(parents=True,exist_ok=True) + +db = models.database.DatabaseHandler(pathlib.Path(dataPath,'icable.db')) users = models.user.Users(db) webSessions = models.web.WebSessions(db) networks = models.network.Networks(db) @@ -340,7 +344,7 @@ class WebInterface(BaseHTTPRequestHandler): user_networks.sort() for network in user_networks: dl_button = "" - if(self.user.get_network_permission(network)==1 or models.user.User.SubnetworkPermission.OWNER in self.user.subnetwork_permission): + if(self.user.get_network_permission(network)==1 or models.user.SubnetworkPermission.OWNER in self.user.subnetwork_permission): dl_button = f"""
@@ -357,7 +361,7 @@ class WebInterface(BaseHTTPRequestHandler): #region firewall mapping['firewallDefault']=firewalls[self.user.subnetid].action.name - FIREWALL_ALLOWED = models.user.User.SubnetworkPermission.FIREWALL|models.user.User.SubnetworkPermission.OWNER + FIREWALL_ALLOWED = models.user.SubnetworkPermission.FIREWALL|models.user.SubnetworkPermission.OWNER mapping['firewallRuleList']="" for rule in firewalls[self.user.subnetid].rules: diff --git a/icable_server/models/__init__.py b/icable_server/models/__init__.py index 8f5f6d0..10feaa4 100644 --- a/icable_server/models/__init__.py +++ b/icable_server/models/__init__.py @@ -1,4 +1,5 @@ from . import database +from sqlite3 import Cursor __all__ = ["user","network","firewall","web","database"] @@ -6,4 +7,4 @@ class Model(): """Abstract class for data model calss""" def __init__(self,database:database.DatabaseHandler): - self._database = database \ No newline at end of file + self._database = database diff --git a/icable_server/models/database.py b/icable_server/models/database.py index e031ca9..b84eb34 100644 --- a/icable_server/models/database.py +++ b/icable_server/models/database.py @@ -2,16 +2,12 @@ import threading import pathlib import sqlite3 -class DatabaseHandler(): +class DatabaseHandler(sqlite3.Connection): - def __init__(self,dbPath:str): - self._semaphore = threading.BoundedSemaphore() - dataPath = pathlib.Path(pathlib.Path.home(),'.local','share','icable') - dataPath.mkdir(parents=True,exist_ok=True) - self._database = sqlite3.connect(pathlib.Path(dataPath,dbPath),check_same_thread=False,detect_types=sqlite3.PARSE_DECLTYPES) + def __init__(self,*args,**kwargs): + super().__init__(*args,check_same_thread=False,detect_types=sqlite3.PARSE_DECLTYPES,**kwargs) #self._database.row_factory = Row - self._cursor = self._database.cursor() - self._cursor.executescript(""" + self.executescript(""" BEGIN; CREATE TABLE IF NOT EXISTS "users" ( login TEXT PRIMARY KEY CHECK(typeof("login") = 'text'), @@ -46,28 +42,9 @@ class DatabaseHandler(): END; COMMIT; PRAGMA foreign_keys=ON;""") - - def execute(self,*args,**kwargs): - return self._cursor.execute(*args,**kwargs) - def fetchone(self,*args,**kwargs): - return self._cursor.fetchone(*args,**kwargs) - def fetchall(self,*args,**kwargs): - return self._cursor.fetchall(*args,**kwargs) - def fetchmany(self,*args,**kwargs): - return self._cursor.fetchmany(*args,**kwargs) - def commit(self,*args,**kwargs): - return self.database.commit() - - @property - def semaphore(self)->threading.BoundedSemaphore: - return self._semaphore - - @property - def database(self): - return self._database def __del__(self): - self._cursor.execute("PRAGMA optimize;") - self.database.commit() - self._cursor.close() - self.database.close() + self.rollback() + self.execute("PRAGMA optimize;") + self.commit() + self.close() diff --git a/icable_server/models/firewall.py b/icable_server/models/firewall.py index 8447e51..7304838 100644 --- a/icable_server/models/firewall.py +++ b/icable_server/models/firewall.py @@ -4,13 +4,12 @@ from icable.firewall import Firewall class Firewalls(Model): def __getitem__(self,key:int): - self._database.execute("SELECT pickle FROM firewall WHERE subnetid = ?",(key,)) - res=self._database.fetchone() + cursor=self._database.execute("SELECT pickle FROM firewall WHERE subnetid = ?",(key,)) + res=cursor.fetchone() if(not res): return Firewall() return res[0] def __setitem__(self,key:int,firewall:Firewall): - with self._database.semaphore: - self._database.execute("INSERT OR REPLACE INTO firewall VALUES(?,?)",(key,firewall)) - self._database.commit() \ No newline at end of file + self._database.execute("INSERT OR REPLACE INTO firewall VALUES(?,?)",(key,firewall)) + self._database.commit() \ No newline at end of file diff --git a/icable_server/models/network.py b/icable_server/models/network.py index 87114d4..361e09f 100644 --- a/icable_server/models/network.py +++ b/icable_server/models/network.py @@ -30,37 +30,35 @@ class Network(): class Networks(Model): def getNetworkById(self,id:int): - self._database.execute("SELECT subnetid,network,nmask FROM networks WHERE id = ?",(id,)) - res = self._database.fetchone() + cursor=self._database.execute("SELECT subnetid,network,nmask FROM networks WHERE id = ?",(id,)) + res = cursor.fetchone() if(not res): return False return Network(id,res[0],ipaddress.ip_network((res[1],ipaddress.ip_address(res[2]).compressed))) def getNetworksInSubnet(self,subnetid:int): - self._database.execute("SELECT id,subnetid,network,nmask FROM networks WHERE subnetid = ?",(subnetid,)) - res = self._database.fetchall() + cursor=self._database.execute("SELECT id,subnetid,network,nmask FROM networks WHERE subnetid = ?",(subnetid,)) + res = cursor.fetchall() if(not res): return list[Network]() return [Network(net[0],subnetid,ipaddress.ip_network(net[2],ipaddress.ip_address(net[3]).compressed))for net in res] def getAllNetworks(self): - self._database.execute("SELECT id,subnetid,network,nmask FROM networks") - res = self._database.fetchall() + cursor=self._database.execute("SELECT id,subnetid,network,nmask FROM networks") + res = cursor.fetchall() if(not res): return list[Network]() return [Network(net[0],net[1],ipaddress.ip_network((net[2],ipaddress.ip_address(net[3]).compressed)))for net in res] def createNetwork(self,subnetwork:int,network:ipaddress.IPv4Network) : - with self._database.semaphore: - self._database.execute("INSERT INTO networks(subnetid,network,nmask) VALUES(?,?,?)",(subnetwork,int(network.network_address),int(network.netmask))) - self._database.commit() - self._database.execute("SELECT id from networks WHERE subnetid = ? AND network = ? and nmask = ?",(subnetwork,int(network.network_address),int(network.netmask))) - res = self._database.fetchone() + self._database.execute("INSERT INTO networks(subnetid,network,nmask) VALUES(?,?,?)",(subnetwork,int(network.network_address),int(network.netmask))) + self._database.commit() + cursor=self._database.execute("SELECT id from networks WHERE subnetid = ? AND network = ? and nmask = ?",(subnetwork,int(network.network_address),int(network.netmask))) + res = cursor.fetchone() nt = self.getNetworkById(res[0]) assert(isinstance(nt,Network)) return nt def deleteNetwork(self,network:Network): - with self._database.semaphore: - self._database.execute("DELETE FROM networks WHERE id = ?",(network.id,)) - self._database.commit() + self._database.execute("DELETE FROM networks WHERE id = ?",(network.id,)) + self._database.commit() diff --git a/icable_server/models/user.py b/icable_server/models/user.py index bb21d78..b928efb 100644 --- a/icable_server/models/user.py +++ b/icable_server/models/user.py @@ -20,12 +20,11 @@ class User(Model): @property def password(self)->str: - self._database.execute("SELECT password FROM users WHERE login = ?",(self.login,)) - return self._database.fetchone()[0] + cursor=self._database.execute("SELECT password FROM users WHERE login = ?",(self.login,)) + return cursor.fetchone()[0] @password.setter def password(self,value:str): - with self._database.semaphore: self._database.execute("UPDATE users SET password = ? WHERE login = ?",(self.hashpassword(value),self.login)) self._database.commit() @@ -72,12 +71,11 @@ class User(Model): @property def subnetid(self): - self._database.execute("SELECT subnetid FROM users WHERE login = ?",(self.login,)) - return self._database.fetchone()[0] + cursor=self._database.execute("SELECT subnetid FROM users WHERE login = ?",(self.login,)) + return cursor.fetchone()[0] @subnetid.setter def subnetid(self,value): - with self._database.semaphore: self._database.execute("UPDATE users SET subnetid = ? WHERE login = ?",(value,self.login)) self._database.commit() @@ -87,9 +85,9 @@ class User(Model): @property def networks(self)->list[Network]: - self._database.execute("""SELECT id from networks WHERE id IN + cursor=self._database.execute("""SELECT id from networks WHERE id IN (SELECT networkid FROM networkUsers WHERE userid = ? OR subnetid = ?);""",(self.login,self.subnetid)) - networks = self._database.fetchall() + networks = cursor.fetchall() networksDB = Networks(self._database) return [networksDB.getNetworkById(res[0]) for res in networks] @@ -98,8 +96,8 @@ class User(Model): self._database.commit() def get_network_permission(self,net:Network)->int: - self._database.execute("SELECT permissions FROM networkUsers WHERE userid = ? AND networkid = ?",(self.login,net.id)) - perm = self._database.fetchone() + cursor=self._database.execute("SELECT permissions FROM networkUsers WHERE userid = ? AND networkid = ?",(self.login,net.id)) + perm = cursor.fetchone() if(not perm): return 0 else: @@ -107,28 +105,32 @@ class User(Model): @property def subnetwork_permission(self)->SubnetworkPermission: - self._database.execute("SELECT subnetperm FROM users WHERE login = ?",(self.login,)) - res = self._database.fetchone() + cursor=self._database.execute("SELECT subnetperm FROM users WHERE login = ?",(self.login,)) + res = cursor.fetchone() assert(res) return SubnetworkPermission(res[0]) -class Users(): +class Users(Model): + def __init__(self,database_handler:DatabaseHandler): - self._database = database_handler - self._database.execute("SELECT * FROM users LIMIT 1") - if(self._database.fetchone()==None): + super().__init__(database_handler) + self.initUsers() + + def initUsers(self): + cursor=self._database.execute("SELECT * FROM users LIMIT 1") + if(cursor.fetchone()==None): self.createUser('admin','admin') def getUserFromLogin(self,login:str): - self._database.execute("SELECT * FROM users WHERE login = ?",(login,)) - if(self._database.fetchone()): + cursor=self._database.execute("SELECT * FROM users WHERE login = ?",(login,)) + if(cursor.fetchone()): return User(self._database,login) else: return False def get_users_in_subnet(self,subnetid:int): - self._database.execute("SELECT login FROM users WHERE subnetid = ?",(subnetid,)) - return [self.getUserFromLogin(res[0]) for res in self._database.fetchall()] + cursor=self._database.execute("SELECT login FROM users WHERE subnetid = ?",(subnetid,)) + return [self.getUserFromLogin(res[0]) for res in cursor.fetchall()] def createUser(self,login:str,password:str,subnetid=None): if(self.getUserFromLogin(login)): @@ -139,15 +141,13 @@ class Users(): subnetid = random.getrandbits(32) subnetperm |= SubnetworkPermission.OWNER #TODO : make sure subnetid is uniq - with self._database.semaphore: - self._database.execute("INSERT INTO users VALUES(?,?,?,?)",(login,User.hashpassword(password),subnetid,int(subnetperm))) - self._database.commit() + cursor=self._database.execute("INSERT INTO users VALUES(?,?,?,?)",(login,User.hashpassword(password),subnetid,int(subnetperm))) + self._database.commit() return self.getUserFromLogin(login) def deleteUser(self,user:User): subnetwork = user.subnetid - with self._database.semaphore: - self._database.execute("DELETE FROM users WHERE login = ?",(user.login,)) - if(not self._database.execute("SELECT * FROM users WHERE subnetid = ?",(subnetwork,)).fetchall()): - self._database.execute("DELETE FROM networks WHERE subnetid = ?",(subnetwork,)) - self._database.commit() \ No newline at end of file + self._database.execute("DELETE FROM users WHERE login = ?",(user.login,)) + if(not self._database.execute("SELECT * FROM users WHERE subnetid = ?",(subnetwork,)).fetchall()): + self._database.execute("DELETE FROM networks WHERE subnetid = ?",(subnetwork,)) + self._database.commit() \ No newline at end of file diff --git a/icable_server/models/web.py b/icable_server/models/web.py index 6eb0d1c..fde44c8 100644 --- a/icable_server/models/web.py +++ b/icable_server/models/web.py @@ -6,8 +6,8 @@ class WebSessions(Model): def getSession(self,sid:int): """Return the user the session belong to or False""" - self._database.execute("SELECT sid,login FROM sessions WHERE sid = ?",(sid,)) - res = self._database.fetchone() + cursor=self._database.execute("SELECT sid,login FROM sessions WHERE sid = ?",(sid,)) + res = cursor.fetchone() if(res): return Users(self._database).getUserFromLogin(res[1]) else: @@ -16,12 +16,10 @@ class WebSessions(Model): def createSession(self,user:User): """Create a session for the user. Return the session id""" sid = getrandbits(32) - with self._database.semaphore: - self._database.execute("INSERT INTO sessions VALUES (?,?)",(sid,user.login)) - self._database.commit() + self._database.execute("INSERT INTO sessions VALUES (?,?)",(sid,user.login)) + self._database.commit() return sid def deleteSession(self,sid:int): - with self._database.semaphore: - self._database.execute("DELETE FROM sessions WHERE sid = ?",(sid,)) - self._database.commit() \ No newline at end of file + self._database.execute("DELETE FROM sessions WHERE sid = ?",(sid,)) + self._database.commit() \ No newline at end of file