This commit is contained in:
@@ -0,0 +1,39 @@
|
||||
from collections.abc import Generator
|
||||
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.orm import Session, declarative_base, sessionmaker
|
||||
|
||||
|
||||
Base = declarative_base()
|
||||
|
||||
_ENGINE: Engine | None = None
|
||||
_SessionLocal: sessionmaker[Session] | None = None
|
||||
|
||||
|
||||
def init_engine(db_url: str) -> Engine:
|
||||
global _ENGINE, _SessionLocal
|
||||
|
||||
connect_args = {}
|
||||
if db_url.startswith("sqlite"):
|
||||
connect_args["check_same_thread"] = False
|
||||
|
||||
_ENGINE = create_engine(db_url, future=True, connect_args=connect_args)
|
||||
_SessionLocal = sessionmaker(bind=_ENGINE, autoflush=False, autocommit=False, future=True)
|
||||
return _ENGINE
|
||||
|
||||
|
||||
def get_engine() -> Engine:
|
||||
if _ENGINE is None:
|
||||
raise RuntimeError("Database engine is not initialized")
|
||||
return _ENGINE
|
||||
|
||||
|
||||
def get_db() -> Generator[Session, None, None]:
|
||||
if _SessionLocal is None:
|
||||
raise RuntimeError("SessionLocal is not initialized")
|
||||
db = _SessionLocal()
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
Reference in New Issue
Block a user