|
| 1 | +import os |
| 2 | +from unittest.mock import patch |
| 3 | + |
| 4 | +from pytest import fixture |
| 5 | +from sqlalchemy import create_engine |
| 6 | + |
| 7 | + |
| 8 | +@fixture(scope="session") |
| 9 | +def db_url(): |
| 10 | + """Default db url used by depending fixtures. |
| 11 | +
|
| 12 | + When CI key is set in environment variables, it uses `postgres` as host name: |
| 13 | + postgresql://postgres@posgres/postgres |
| 14 | +
|
| 15 | + Else, host used is `localhost`: postgresql://postgres@localhost/postgres |
| 16 | + """ |
| 17 | + host = "postgres" if "CI" in os.environ else "localhost" |
| 18 | + return f"postgresql://postgres@{host}/postgres" |
| 19 | + |
| 20 | + |
| 21 | +@fixture(scope="session") |
| 22 | +def sqla_connection(db_url): |
| 23 | + engine = create_engine(db_url) |
| 24 | + connection = engine.connect() |
| 25 | + yield connection |
| 26 | + connection.close() |
| 27 | + |
| 28 | + |
| 29 | +@fixture |
| 30 | +def sqla_modules(): |
| 31 | + raise Exception( |
| 32 | + "sqla_modules fixture is not defined. Define a sqla_modules fixture which " |
| 33 | + "imports all modules with sqla entities deriving from fastapi_sqla.Base ." |
| 34 | + ) |
| 35 | + |
| 36 | + |
| 37 | +@fixture(autouse=True) |
| 38 | +def sqla_reflection(sqla_modules, sqla_connection, db_url): |
| 39 | + import fastapi_sqla |
| 40 | + |
| 41 | + fastapi_sqla.Base.metadata.bind = sqla_connection |
| 42 | + fastapi_sqla.Base.prepare(sqla_connection) |
| 43 | + |
| 44 | + |
| 45 | +@fixture(autouse=True) |
| 46 | +def patch_sessionmaker(db_url, sqla_connection, sqla_transaction): |
| 47 | + """So that all DB operations are never written to db for real.""" |
| 48 | + with patch("fastapi_sqla.engine_from_config") as engine_from_config: |
| 49 | + engine_from_config.return_value = sqla_connection |
| 50 | + yield engine_from_config |
| 51 | + |
| 52 | + |
| 53 | +@fixture |
| 54 | +def sqla_transaction(sqla_connection): |
| 55 | + transaction = sqla_connection.begin() |
| 56 | + yield transaction |
| 57 | + transaction.rollback() |
| 58 | + |
| 59 | + |
| 60 | +@fixture |
| 61 | +def session(sqla_transaction, sqla_connection): |
| 62 | + """Sqla session to use when creating db fixtures. |
| 63 | +
|
| 64 | + While it does not write any record in DB, the application will still be able to access any record |
| 65 | + committed with that session. |
| 66 | + """ |
| 67 | + import fastapi_sqla |
| 68 | + |
| 69 | + session = fastapi_sqla._Session(bind=sqla_connection) |
| 70 | + yield session |
| 71 | + session.close() |
0 commit comments