diff --git a/libsql/src/sync.rs b/libsql/src/sync.rs
index a1d6b742f9..a8fd090df5 100644
--- a/libsql/src/sync.rs
+++ b/libsql/src/sync.rs
@@ -60,6 +60,10 @@ pub enum SyncError {
     RedirectHeader(http::header::ToStrError),
     #[error("redirect response with no location header")]
     NoRedirectLocationHeader,
+    #[error("failed to pull db export: status={0}, error={1}")]
+    PullDb(StatusCode, String),
+    #[error("server returned a lower generation than local: local={0}, remote={1}")]
+    InvalidLocalGeneration(u32, u32),
 }
 
 impl SyncError {
@@ -86,6 +90,11 @@ pub enum PullResult {
     EndOfGeneration { max_generation: u32 },
 }
 
+#[derive(serde::Deserialize)]
+struct InfoResult {
+    current_generation: u32,
+}
+
 pub struct SyncContext {
     db_path: String,
     client: hyper::Client<ConnectorService, Body>,
@@ -97,6 +106,9 @@ pub struct SyncContext {
     durable_generation: u32,
     /// Represents the max_frame_no from the server.
     durable_frame_num: u32,
+    /// whenever sync is called very first time, we will call the remote server
+    /// to get the generation information and sync the db file if needed
+    initial_server_sync: bool,
 }
 
 impl SyncContext {
@@ -123,8 +135,9 @@ impl SyncContext {
             max_retries: DEFAULT_MAX_RETRIES,
             push_batch_size: DEFAULT_PUSH_BATCH_SIZE,
             client,
-            durable_generation: 1,
+            durable_generation: 0,
             durable_frame_num: 0,
+            initial_server_sync: false,
         };
 
         if let Err(e) = me.read_metadata().await {
@@ -173,7 +186,7 @@ impl SyncContext {
             frame_no,
             frame_no + frames_count
         );
-        tracing::debug!("pushing frame");
+        tracing::debug!("pushing frame(frame_no={}, count={}, generation={})", frame_no, frames_count, generation);
 
         let result = self.push_with_retry(uri, frames, self.max_retries).await?;
 
@@ -458,6 +471,105 @@ impl SyncContext {
 
         Ok(())
     }
+
+    /// get_remote_info calls the remote server to get the current generation information.
+    async fn get_remote_info(&self) -> Result<InfoResult> {
+        let uri = format!("{}/info", self.sync_url);
+        let mut req = http::Request::builder().method("GET").uri(&uri);
+
+        if let Some(auth_token) = &self.auth_token {
+            req = req.header("Authorization", auth_token);
+        }
+
+        let req = req.body(Body::empty()).expect("valid request");
+
+        let res = self
+            .client
+            .request(req)
+            .await
+            .map_err(SyncError::HttpDispatch)?;
+
+        if !res.status().is_success() {
+            let status = res.status();
+            let body = hyper::body::to_bytes(res.into_body())
+                .await
+                .map_err(SyncError::HttpBody)?;
+            return Err(
+                SyncError::PullDb(status, String::from_utf8_lossy(&body).to_string()).into(),
+            );
+        }
+
+        let body = hyper::body::to_bytes(res.into_body())
+            .await
+            .map_err(SyncError::HttpBody)?;
+
+        let info = serde_json::from_slice(&body).map_err(SyncError::JsonDecode)?;
+
+        Ok(info)
+    }
+
+    async fn sync_db_if_needed(&mut self, generation: u32) -> Result<()> {
+        // we will get the export file only if the remote generation is different from the one we have
+        if generation == self.durable_generation {
+            return Ok(());
+        }
+        // somehow we are ahead of the remote in generations. following should not happen because
+        // we checkpoint only if the remote server tells us to do so.
+        if self.durable_generation > generation {
+            tracing::error!(
+                "server returned a lower generation than what we have: sent={}, got={}",
+                self.durable_generation,
+                generation
+            );
+            return Err(
+                SyncError::InvalidLocalGeneration(self.durable_generation, generation).into(),
+            );
+        }
+        tracing::debug!(
+            "syncing db file from remote server, generation={}",
+            generation
+        );
+        self.sync_db(generation).await
+    }
+
+    /// sync_db will download the db file from the remote server and replace the local file.
+    async fn sync_db(&mut self, generation: u32) -> Result<()> {
+        let uri = format!("{}/export/{}", self.sync_url, generation);
+        let mut req = http::Request::builder().method("GET").uri(&uri);
+
+        if let Some(auth_token) = &self.auth_token {
+            req = req.header("Authorization", auth_token);
+        }
+
+        let req = req.body(Body::empty()).expect("valid request");
+
+        let res = self
+            .client
+            .request(req)
+            .await
+            .map_err(SyncError::HttpDispatch)?;
+
+        if !res.status().is_success() {
+            let status = res.status();
+            let body = hyper::body::to_bytes(res.into_body())
+                .await
+                .map_err(SyncError::HttpBody)?;
+            return Err(
+                SyncError::PullDb(status, String::from_utf8_lossy(&body).to_string()).into(),
+            );
+        }
+
+        // todo: do streaming write to the disk
+        let bytes = hyper::body::to_bytes(res.into_body())
+            .await
+            .map_err(SyncError::HttpBody)?;
+
+        atomic_write(&self.db_path, &bytes).await?;
+        self.durable_generation = generation;
+        self.durable_frame_num = 0;
+        self.write_metadata().await?;
+        Ok(())
+    }
 }
 
 #[derive(serde::Serialize, serde::Deserialize, Debug)]
@@ -555,6 +667,22 @@ pub async fn sync_offline(
             Err(e) => Err(e),
         }
     } else {
+        // todo: we are checking with the remote server only during initialisation. ideally,
+        // we should check every time we try to sync with the remote server. However, we need to close
+        // all the ongoing connections since we replace `.db` file and remove the `.db-wal` file
+        if !sync_ctx.initial_server_sync {
+            // sync is being called first time. so we will call remote, get the generation information
+            // if we are lagging behind, then we will call the export API and get to the latest
+            // generation directly.
+            let info = sync_ctx.get_remote_info().await?;
+            sync_ctx
+                .sync_db_if_needed(info.current_generation)
+                .await?;
+            // when sync_ctx is initialised, we set durable_generation to 0. however, once
+            // sync_db is called, it should be > 0.
+            assert!(sync_ctx.durable_generation > 0, "generation should be > 0");
+            sync_ctx.initial_server_sync = true;
+        }
         try_pull(sync_ctx, conn).await
     }
     .or_else(|err| {
diff --git a/libsql/src/sync/test.rs b/libsql/src/sync/test.rs
index 38af3902fa..2417fa8158 100644
--- a/libsql/src/sync/test.rs
+++ b/libsql/src/sync/test.rs
@@ -129,7 +129,7 @@ async fn test_sync_context_corrupted_metadata() {
 
     // Verify that the context was reset to default values
     assert_eq!(sync_ctx.durable_frame_num(), 0);
-    assert_eq!(sync_ctx.durable_generation(), 1);
+    assert_eq!(sync_ctx.durable_generation(), 0);
 }
 
 #[tokio::test]