40 lines
1.0 KiB
Python
40 lines
1.0 KiB
Python
from collections.abc import Generator
|
|
|
|
from sqlalchemy import create_engine
|
|
from sqlalchemy.engine import Engine
|
|
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
|
|
|
|
|
# Declarative base class shared by the application's ORM models.
# NOTE(review): if this project is on SQLAlchemy 2.x, the typed
# `DeclarativeBase` class is the documented successor to declarative_base();
# consider migrating when type checking of models matters.
Base = declarative_base()

# Module-wide database handles. Both start unset and are populated by
# init_engine(); get_engine() / get_db() raise RuntimeError while they are None.
_ENGINE: Engine | None = None
_SessionLocal: sessionmaker[Session] | None = None
|
|
|
|
|
|
def init_engine(db_url: str) -> Engine:
    """Create the module-wide engine and session factory for *db_url*.

    Any engine and session factory set by an earlier call are replaced. The
    earlier engine's connection pool is disposed once the replacement has been
    built successfully. Its pooled connections are released right away instead
    of lingering until garbage collection. If building the new engine raises,
    the previous engine is left untouched.

    Args:
        db_url: SQLAlchemy database URL, e.g. ``"sqlite:///app.db"``.

    Returns:
        The newly created engine (also returned afterwards by ``get_engine()``).
    """
    global _ENGINE, _SessionLocal

    connect_args: dict[str, object] = {}
    if db_url.startswith("sqlite"):
        # sqlite3 connections by default refuse to be used from a thread other
        # than the one that created them. Pooled connections may be handed to
        # other threads, so disable that check for SQLite URLs.
        connect_args["check_same_thread"] = False

    previous = _ENGINE
    _ENGINE = create_engine(db_url, future=True, connect_args=connect_args)
    _SessionLocal = sessionmaker(
        bind=_ENGINE, autoflush=False, autocommit=False, future=True
    )

    # Release the superseded engine's pooled connections. Engine.dispose()
    # leaves checked-out connections alone, and the old Engine object stays
    # usable (it gets a fresh pool) for any caller still holding it.
    if previous is not None:
        previous.dispose()

    return _ENGINE
|
|
|
|
|
|
def get_engine() -> Engine:
    """Return the engine created by the most recent ``init_engine()`` call.

    Raises:
        RuntimeError: if ``init_engine()`` has not been called yet.
    """
    engine = _ENGINE
    if engine is not None:
        return engine
    raise RuntimeError("Database engine is not initialized")
|
|
|
|
|
|
def get_db() -> Generator[Session, None, None]:
    """Yield a new ORM session and close it when the generator finishes.

    The session comes from the factory installed by ``init_engine()``. The
    check runs when the generator is first advanced, not when ``get_db()`` is
    called.

    Raises:
        RuntimeError: if ``init_engine()`` has not been called yet.
    """
    factory = _SessionLocal
    if factory is None:
        raise RuntimeError("SessionLocal is not initialized")

    # Session.__exit__ calls close(), so the session is closed however the
    # consumer exits: normal completion, exception, or generator close().
    with factory() as session:
        yield session
|