|
12 | 12 | from pydantic import BaseModel, Field |
13 | 13 | from pydantic.generics import GenericModel |
14 | 14 | from sqlalchemy import engine_from_config, text |
| 15 | +from sqlalchemy.engine import Engine |
15 | 16 | from sqlalchemy.ext.declarative import DeferredReflection |
16 | 17 | from sqlalchemy.orm import Query as LegacyQuery |
17 | 18 | from sqlalchemy.orm.session import Session as SqlaSession |
|
51 | 52 | _Session = sessionmaker() |
52 | 53 |
|
53 | 54 |
|
| 55 | +def new_engine(*, envvar_prefix: str = None) -> Engine: |
| 56 | + envvar_prefix = envvar_prefix if envvar_prefix else "sqlalchemy_" |
| 57 | + lowercase_environ = { |
| 58 | + k.lower(): v for k, v in os.environ.items() if k.lower() != "sqlalchemy_warn_20" |
| 59 | + } |
| 60 | + return engine_from_config(lowercase_environ, prefix=envvar_prefix) |
| 61 | + |
| 62 | + |
| 63 | +def is_async_dialect(engine): |
| 64 | + return engine.dialect.is_async if hasattr(engine.dialect, "is_async") else False |
| 65 | + |
| 66 | + |
54 | 67 | def setup(app: FastAPI): |
55 | | - app.add_event_handler("startup", startup) |
56 | | - app.middleware("http")(add_session_to_request) |
| 68 | + engine = new_engine() |
| 69 | + |
| 70 | + if not is_async_dialect(engine): |
| 71 | + app.add_event_handler("startup", startup) |
| 72 | + app.middleware("http")(add_session_to_request) |
57 | 73 |
|
58 | | - async_sqlalchemy_url = os.getenv("async_sqlalchemy_url") |
59 | | - if async_sqlalchemy_url: |
| 74 | + has_async_config = "async_sqlalchemy_url" in os.environ or is_async_dialect(engine) |
| 75 | + if has_async_config: |
60 | 76 | assert asyncio_support, asyncio_support_err |
61 | 77 | app.add_event_handler("startup", asyncio_support.startup) |
62 | 78 | app.middleware("http")(asyncio_support.add_session_to_request) |
63 | 79 |
|
64 | 80 |
|
65 | 81 | def startup(): |
66 | | - lowercase_environ = { |
67 | | - k.lower(): v for k, v in os.environ.items() if k.lower() != "sqlalchemy_warn_20" |
68 | | - } |
69 | | - engine = engine_from_config(lowercase_environ, prefix="sqlalchemy_") |
| 82 | + engine = new_engine() |
70 | 83 | aws_rds_iam_support.setup(engine.engine) |
71 | 84 |
|
72 | | - Base.metadata.bind = engine |
73 | | - Base.prepare(engine) |
74 | | - _Session.configure(bind=engine) |
75 | | - |
76 | 85 | # Fail early: |
77 | 86 | try: |
78 | | - with open_session() as session: |
79 | | - session.execute(text("select 'OK'")) |
| 87 | + with engine.connect() as connection: |
| 88 | + connection.execute(text("select 'OK'")) |
80 | 89 | except Exception: |
81 | 90 | logger.critical( |
82 | 91 | "Fail querying db: is sqlalchemy_url envvar correctly configured?" |
83 | 92 | ) |
84 | 93 | raise |
85 | 94 |
|
| 95 | + Base.prepare(engine) |
| 96 | + _Session.configure(bind=engine) |
86 | 97 | logger.info("startup", engine=engine) |
87 | 98 |
|
88 | 99 |
|
|
0 commit comments