|
import sqlite3 |
|
from sqlite3 import Error |
|
from pydantic import BaseModel |
|
|
|
class User(BaseModel):
    """Pydantic model describing a single user record.

    All three fields are required strings; pydantic validates them when an
    instance is constructed.
    """

    # Name the user is looked up by.
    username: str

    # E-mail address stored for the user.
    email: str

    # Stored password value; presumably already hashed by the caller,
    # since no hashing happens in this model -- TODO confirm.
    hashed_password: str
|
|
|
|
|
def create_connection(db_file):
    """Open a SQLite connection to ``db_file``.

    Returns the ``sqlite3.Connection`` on success.  If ``sqlite3.connect``
    raises ``sqlite3.Error`` the error is printed and ``None`` is returned,
    so callers have to check the result before using it.
    """
    try:
        return sqlite3.connect(db_file)
    except Error as exc:
        print(exc)

    # Only reached when connecting failed.
    return None
|
|
|
|
|
def create_table(conn):
    """Make sure the ``users`` table exists on the given connection.

    The table has an integer primary key ``id`` plus the NOT NULL text
    columns ``username`` (unique), ``email`` (unique) and
    ``hashed_password``.  Any ``sqlite3.Error`` raised while obtaining the
    cursor or executing the statement is printed and swallowed; nothing is
    returned and no explicit commit is issued here.
    """
    ddl = """ CREATE TABLE IF NOT EXISTS users (
                id integer PRIMARY KEY,
                username text NOT NULL UNIQUE,
                email text NOT NULL UNIQUE,
                hashed_password text NOT NULL
            ); """
    try:
        conn.cursor().execute(ddl)
    except Error as exc:
        print(exc)
|
|
|
|
|
def create_db_and_table():
    """Create ``users.db`` (if needed) and its ``users`` table.

    Opens a connection through ``create_connection``; when that yields
    ``None`` an error message is printed and nothing else happens.
    Otherwise ``create_table`` is run on the connection, which is then
    closed.
    """
    conn = create_connection("users.db")

    # Guard clause: no connection means nothing to set up.
    if conn is None:
        print("Error! cannot create the database connection.")
        return

    create_table(conn)
    conn.close()
|
|
|
|
|
def insert_user(user_data: dict):
    """Insert a new row into ``users`` and return the stored user.

    Args:
        user_data: Mapping that must contain the keys ``'username'``,
            ``'email'`` and ``'hashed_password'``.

    Returns:
        The result of ``get_user(user_data['username'])`` when the insert
        reported a row id, otherwise ``None``.

    Raises:
        KeyError: If one of the required keys is missing.  This is raised
            before any connection is opened.
        sqlite3.Error: Propagated unchanged from the INSERT or commit
            (for example a constraint violation).  The connection is
            closed before the exception leaves this function.
    """
    # Pull the values out first so a missing key cannot leak an open
    # connection.
    params = (
        user_data['username'],
        user_data['email'],
        user_data['hashed_password'],
    )
    sql = """ INSERT INTO users(username,email,hashed_password)
              VALUES(?,?,?) """

    conn = create_connection('users.db')
    # Kept outside the try block: if create_connection returned None this
    # fails exactly as it always did, without a second error from close().
    cur = conn.cursor()
    try:
        cur.execute(sql, params)
        conn.commit()
        last_row_id = cur.lastrowid
    finally:
        # Always release the connection, even when the INSERT fails.
        conn.close()

    if last_row_id:
        return get_user(user_data['username'])
    return None
|
|
|
|
|
def get_user(username: str):
    """Fetch a single user by username from ``users.db``.

    Args:
        username: Value matched against the ``username`` column.

    Returns:
        A ``User`` built from the matching row, or ``None`` when no row
        matches.

    Raises:
        sqlite3.Error: Propagated unchanged from the query.  The connection
            is closed before the exception leaves this function.
    """
    conn = create_connection('users.db')
    # Kept outside the try block: if create_connection returned None this
    # fails exactly as it always did, without a second error from close().
    cur = conn.cursor()
    try:
        # Columns are named explicitly so the mapping below does not depend
        # on the physical column order of the table.
        cur.execute(
            "SELECT username, email, hashed_password FROM users WHERE username=?",
            (username,),
        )
        row = cur.fetchone()
    finally:
        # Always release the connection, even when the query fails.
        conn.close()

    if row is None:
        return None
    return User(username=row[0], email=row[1], hashed_password=row[2])
|
|
|
|
|
def delete_user(username: str) -> bool:
    """Delete the row whose ``username`` column equals ``username``.

    Args:
        username: Value matched against the ``username`` column.

    Returns:
        ``True`` when at least one row was deleted, ``False`` otherwise.

    Raises:
        sqlite3.Error: Propagated unchanged from the DELETE or commit.  The
            connection is closed before the exception leaves this function.
    """
    conn = create_connection('users.db')
    # Kept outside the try block: if create_connection returned None this
    # fails exactly as it always did, without a second error from close().
    cur = conn.cursor()
    try:
        cur.execute("DELETE FROM users WHERE username=?", (username,))
        conn.commit()
        rows_affected = cur.rowcount
    finally:
        # Always release the connection, even when the DELETE fails.
        conn.close()

    return rows_affected > 0
|
|
|
|
|
def update_user(username: str, updated_data: dict):
    """Overwrite the e-mail and password hash of an existing user.

    Args:
        username: Value matched against the ``username`` column.
        updated_data: Mapping that must contain the keys ``'email'`` and
            ``'hashed_password'``; both columns are always rewritten.

    Returns:
        The result of ``get_user(username)`` when at least one row was
        updated, otherwise ``None``.

    Raises:
        KeyError: If one of the required keys is missing.  This is raised
            before any connection is opened.
        sqlite3.Error: Propagated unchanged from the UPDATE or commit
            (for example a constraint violation).  The connection is
            closed before the exception leaves this function.
    """
    # Pull the values out first so a missing key cannot leak an open
    # connection.
    params = (updated_data['email'], updated_data['hashed_password'], username)
    sql = """ UPDATE users
              SET email = ?,
                  hashed_password = ?
              WHERE username = ?"""

    conn = create_connection('users.db')
    # Kept outside the try block: if create_connection returned None this
    # fails exactly as it always did, without a second error from close().
    cur = conn.cursor()
    try:
        cur.execute(sql, params)
        conn.commit()
        rows_affected = cur.rowcount
    finally:
        # Always release the connection, even when the UPDATE fails.
        conn.close()

    if rows_affected > 0:
        return get_user(username)
    return None
|
|
|
def get_all_users():
    """Return every row of the ``users`` table as a list of ``User`` objects.

    Returns:
        A list with one ``User`` per row; empty when the table has no rows.

    Raises:
        sqlite3.Error: Propagated unchanged from the query.  The connection
            is closed before the exception leaves this function.
    """
    conn = create_connection('users.db')
    # Kept outside the try block: if create_connection returned None this
    # fails exactly as it always did, without a second error from close().
    cur = conn.cursor()
    try:
        # Columns are named explicitly so the mapping below does not depend
        # on the physical column order of the table.
        cur.execute("SELECT username, email, hashed_password FROM users")
        rows = cur.fetchall()
    finally:
        # Always release the connection, even when the query fails.
        conn.close()

    return [
        User(username=name, email=email, hashed_password=pw_hash)
        for name, email, pw_hash in rows
    ]