|
1 | 1 | import asyncio |
| 2 | +import math |
2 | 3 | import os |
3 | 4 | from contextlib import contextmanager |
| 5 | +from typing import Callable, Generic, List, TypeVar |
4 | 6 |
|
5 | 7 | import structlog |
6 | | -from fastapi import FastAPI, Request |
| 8 | +from fastapi import Depends, FastAPI, Query, Request |
7 | 9 | from fastapi.concurrency import contextmanager_in_threadpool |
| 10 | +from pydantic import BaseModel, Field |
| 11 | +from pydantic.generics import GenericModel |
8 | 12 | from sqlalchemy import engine_from_config |
9 | 13 | from sqlalchemy.ext.declarative import DeferredReflection, declarative_base |
| 14 | +from sqlalchemy.orm import Query as DbQuery |
10 | 15 | from sqlalchemy.orm.session import Session, sessionmaker |
11 | 16 |
|
12 | 17 | __all__ = ["Base", "setup", "with_session"] |
@@ -108,3 +113,76 @@ def get_users(session: sqla.Session = Depends(sqla.new_session)): |
108 | 113 | await loop.run_in_executor(None, session.rollback) |
109 | 114 |
|
110 | 115 | return response |
| 116 | + |
| 117 | + |
| 118 | +T = TypeVar("T") |
| 119 | + |
| 120 | + |
| 121 | +class Item(GenericModel, Generic[T]): |
| 122 | + """Item container.""" |
| 123 | + |
| 124 | + data: T |
| 125 | + |
| 126 | + |
| 127 | +class Collection(GenericModel, Generic[T]): |
| 128 | + """Collection container.""" |
| 129 | + |
| 130 | + data: List[T] |
| 131 | + |
| 132 | + |
| 133 | +class Meta(BaseModel): |
| 134 | + """Meta information on current page and collection""" |
| 135 | + |
| 136 | + offset: int = Field(..., description="Current page offset") |
| 137 | + total_items: int = Field(..., description="Total number of items in the collection") |
| 138 | + total_pages: int = Field(..., description="Total number of pages in the collection") |
| 139 | + page_number: int = Field(..., description="Current page number. Starts at 1.") |
| 140 | + |
| 141 | + |
| 142 | +class Paginated(Collection, Generic[T]): |
| 143 | + """Paginated collection with information on current page and total items in meta.""" |
| 144 | + |
| 145 | + meta: Meta |
| 146 | + |
| 147 | + |
| 148 | +def _query_count(session: Session, query: DbQuery) -> int: |
| 149 | + """Default function used to count items returned by a query. |
| 150 | +
|
| 151 | + Default Query.count is slower than a manually written query could be: It runs the |
| 152 | + query in a subquery, and count the number of elements returned: |
| 153 | +
|
| 154 | + See https://gist.github.com/hest/8798884 |
| 155 | + """ |
| 156 | + return query.count() |
| 157 | + |
| 158 | + |
| 159 | +def Pagination( |
| 160 | + min_page_size: int = 10, |
| 161 | + max_page_size: int = 100, |
| 162 | + query_count: Callable[[Session, DbQuery], int] = _query_count, |
| 163 | +) -> Callable[[Session, int, int], Callable[[DbQuery], Paginated[T]]]: |
| 164 | + def dependency( |
| 165 | + session: Session = Depends(with_session), |
| 166 | + offset: int = Query(0, ge=0), |
| 167 | + limit: int = Query(min_page_size, ge=1, le=max_page_size), |
| 168 | + ) -> Callable[[DbQuery], Paginated[T]]: |
| 169 | + def paginated_result(query: DbQuery) -> Paginated[T]: |
| 170 | + total_items = query_count(session, query) |
| 171 | + total_pages = math.ceil(total_items / limit) |
| 172 | + page_number = offset / limit + 1 |
| 173 | + return Paginated[T]( |
| 174 | + data=query.offset(offset).limit(limit).all(), |
| 175 | + meta={ |
| 176 | + "offset": offset, |
| 177 | + "total_items": total_items, |
| 178 | + "total_pages": total_pages, |
| 179 | + "page_number": page_number, |
| 180 | + }, |
| 181 | + ) |
| 182 | + |
| 183 | + return paginated_result |
| 184 | + |
| 185 | + return dependency |
| 186 | + |
| 187 | + |
| 188 | +with_pagination = Pagination() |
0 commit comments