Commit ffdc89a

Merge pull request #8 from SecrinLabs/feature/streaming-token

token streaming support

2 parents 8435b24 + 3d64ba9

5 files changed: 196 additions & 3 deletions

apps/api/routes/v1/ask.py

Lines changed: 22 additions & 0 deletions
@@ -1,7 +1,9 @@
 import logging
+import json
 from typing import Any
 
 from fastapi import APIRouter, HTTPException, status
+from fastapi.responses import StreamingResponse
 
 from apps.api.routes.v1.schemas.qa import (
     QARequest,
@@ -15,6 +17,7 @@
 from packages.memory.services.issue_analysis import IssueAnalyzer
 from packages.database.graph.graph import neo4j_client
 from packages.config import Settings
+from packages.config.feature_flags import is_feature_enabled, FeatureFlag
 
 router = APIRouter(prefix="/ask", tags=["Question Answering"])
 settings = Settings()
@@ -107,6 +110,12 @@ async def analyze_issue(request: IssueRequest):
     try:
         logger.info(f"Analyzing issue: {request.title}")
 
+        if is_feature_enabled(FeatureFlag.ENABLE_TOKEN_STREAMING):
+            return StreamingResponse(
+                _stream_issue_analysis(request.title, request.body),
+                media_type="text/event-stream"
+            )
+
         result = issue_analyzer.analyze_issue(request.title, request.body)
 
         if "error" in result:
@@ -131,3 +140,16 @@ async def analyze_issue(request: IssueRequest):
             status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
             detail="An unexpected error occurred. Please try again later.",
         )
+
+
+def _stream_issue_analysis(title: str, body: str):
+    """Generator for streaming issue analysis."""
+    try:
+        for chunk in issue_analyzer.analyze_issue_stream(title, body):
+            yield f"data: {json.dumps(chunk)}\n\n"
+    except Exception as e:
+        logger.exception("Error during streaming issue analysis")
+        yield f"data: {json.dumps({'error': str(e)})}\n\n"
+    finally:
+        yield "data: [DONE]\n\n"
+
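For reference, a minimal client sketch for consuming this SSE stream. The URL is an assumption (the route decorator for analyze_issue is not shown in this diff); the title/body request fields match IssueRequest as used above.

import json
import requests

# Hypothetical URL; substitute the actual route registered for analyze_issue.
resp = requests.post(
    "http://localhost:8000/v1/ask/issue",
    json={"title": "Login fails after upgrade", "body": "Users report 401s..."},
    stream=True,
)
for raw in resp.iter_lines():
    if not raw:
        continue
    line = raw.decode("utf-8")
    if not line.startswith("data: "):
        continue
    data = line[len("data: "):]
    if data == "[DONE]":          # terminal sentinel from _stream_issue_analysis
        break
    event = json.loads(data)      # {"context_used": [...]}, {"chunk": "..."}, or {"error": "..."}
    print(event)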

packages/config/feature_flags.py

Lines changed: 5 additions & 0 deletions
@@ -40,6 +40,7 @@ class FeatureFlag(str, Enum):
     ENABLE_AUTO_INDEXING = "enable_auto_indexing"
     ENABLE_SMART_RETRY = "enable_smart_retry"
     ENABLE_MULTIMODAL_EMBEDDINGS = "enable_multimodal_embeddings"
+    ENABLE_TOKEN_STREAMING = "enable_token_streaming"
 
 
 class FeatureFlagConfig(BaseModel):
@@ -150,6 +151,10 @@ def _initialize_defaults(self) -> Dict[FeatureFlag, FeatureFlagConfig]:
                 enabled=False,
                 environments=["development"]
             ),
+            FeatureFlag.ENABLE_TOKEN_STREAMING: FeatureFlagConfig(
+                enabled=False,
+                environments=["development", "staging", "production"]
+            ),
         }
 
     def is_enabled(self, flag: FeatureFlag) -> bool:
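The new flag ships disabled by default in all three environments, so callers must opt in at runtime. A minimal gate, mirroring the import added to ask.py (how flags are toggled per environment is not shown in this diff):

from packages.config.feature_flags import is_feature_enabled, FeatureFlag

if is_feature_enabled(FeatureFlag.ENABLE_TOKEN_STREAMING):
    ...  # stream tokens to the client
else:
    ...  # fall back to the blocking, single-response path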

packages/memory/llm/base.py

Lines changed: 35 additions & 1 deletion
@@ -4,7 +4,7 @@
 """
 
 from abc import ABC, abstractmethod
-from typing import List, Any, Optional
+from typing import List, Any, Optional, Iterator
 
 
 class BaseLLMProvider(ABC):
@@ -45,6 +45,26 @@ def generate_answer(
             Exception: If generation fails
         """
         pass
+
+    def stream_answer(
+        self,
+        question: str,
+        context_items: List[Any],
+        search_type: str
+    ) -> Iterator[str]:
+        """
+        Stream an answer to the question using provided context.
+
+        Args:
+            question: User's question
+            context_items: List of search results to use as context
+            search_type: Type of search performed (vector/hybrid)
+
+        Returns:
+            Iterator yielding generated answer chunks
+        """
+        prompt = self._build_prompt(question, context_items, search_type)
+        return self.stream_text(prompt)
 
     @abstractmethod
     def generate_text(self, prompt: str, system_prompt: Optional[str] = None) -> str:
@@ -59,6 +79,20 @@ def generate_text(self, prompt: str, system_prompt: Optional[str] = None) -> str:
             Generated text
         """
         pass
+
+    @abstractmethod
+    def stream_text(self, prompt: str, system_prompt: Optional[str] = None) -> Iterator[str]:
+        """
+        Stream text from a raw prompt.
+
+        Args:
+            prompt: The prompt to send to the LLM
+            system_prompt: Optional system prompt
+
+        Returns:
+            Iterator yielding generated text chunks
+        """
+        pass
 
     @abstractmethod
     def get_provider_name(self) -> str:
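To illustrate the contract, a toy subclass sketch. EchoProvider is hypothetical; stream_answer also relies on a _build_prompt helper that this diff does not show, and generate_answer's signature is assumed to mirror stream_answer's:

from typing import Any, Iterator, List, Optional

class EchoProvider(BaseLLMProvider):
    """Hypothetical provider that 'streams' the prompt back one word at a time."""

    def generate_answer(self, question: str, context_items: List[Any], search_type: str) -> str:
        # Assumed signature, mirroring stream_answer above.
        return self.generate_text(question)

    def generate_text(self, prompt: str, system_prompt: Optional[str] = None) -> str:
        return prompt

    def stream_text(self, prompt: str, system_prompt: Optional[str] = None) -> Iterator[str]:
        for word in prompt.split():
            yield word + " "  # each chunk is a partial-answer fragment

    def get_provider_name(self) -> str:
        return "echo"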

packages/memory/llm/providers/ollama.py

Lines changed: 54 additions & 1 deletion
@@ -3,7 +3,8 @@
 """
 
 import requests
-from typing import List, Any, Optional
+import json
+from typing import List, Any, Optional, Iterator
 from packages.memory.llm.base import BaseLLMProvider
 from packages.config.settings import Settings
 
@@ -118,6 +119,58 @@ def generate_text(self, prompt: str, system_prompt: Optional[str] = None) -> str:
             )
         except requests.exceptions.RequestException as e:
             raise Exception(f"Ollama API error: {e}")
+
+    def stream_text(self, prompt: str, system_prompt: Optional[str] = None) -> Iterator[str]:
+        """
+        Stream text using Ollama.
+
+        Args:
+            prompt: The prompt to send
+            system_prompt: Optional system prompt
+
+        Returns:
+            Iterator yielding generated text chunks
+        """
+        try:
+            payload = {
+                "model": self.model,
+                "prompt": prompt,
+                "stream": True,
+                "options": {
+                    "temperature": self.temperature,
+                    "num_predict": self.max_tokens
+                }
+            }
+
+            if system_prompt:
+                payload["system"] = system_prompt
+
+            response = requests.post(
+                f"{self.base_url}/api/generate",
+                json=payload,
+                timeout=self.timeout,
+                stream=True
+            )
+            response.raise_for_status()
+
+            for line in response.iter_lines():
+                if line:
+                    try:
+                        json_response = json.loads(line)
+                        if "response" in json_response:
+                            yield json_response["response"]
+                        if json_response.get("done", False):
+                            break
+                    except json.JSONDecodeError:
+                        continue
+
+        except requests.exceptions.Timeout:
+            raise TimeoutError(
+                f"Ollama request timed out after {self.timeout}s. "
+                "Try increasing LLM_TIMEOUT or reducing context size."
+            )
+        except requests.exceptions.RequestException as e:
+            raise Exception(f"Ollama API error: {e}")
 
     def get_provider_name(self) -> str:
         """Return provider name."""

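Ollama's /api/generate endpoint streams newline-delimited JSON objects of the form {"response": "<token text>", "done": false}, with a final record where "done" is true; the iter_lines loop above yields only the "response" fragments. A hedged usage sketch (the provider class name and its construction are assumptions; neither appears in this diff):

provider = OllamaProvider()  # hypothetical construction; config comes from Settings in this codebase
for fragment in provider.stream_text("Summarize the failing test in auth_test.py"):
    print(fragment, end="", flush=True)  # fragments concatenate into the full completion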
packages/memory/services/issue_analysis.py

Lines changed: 80 additions & 1 deletion
@@ -1,4 +1,4 @@
-from typing import Dict, Any, List, Optional
+from typing import Dict, Any, List, Optional, Iterator
 import logging
 from packages.memory.services.graph_service import GraphService
 from packages.memory.llm import BaseLLMProvider
@@ -104,6 +104,85 @@ def analyze_issue(self, title: str, body: str) -> Dict[str, Any]:
             "context_used": [self._format_item_summary(item) for item in all_context]
         }
 
+    def analyze_issue_stream(self, title: str, body: str) -> Iterator[Dict[str, Any]]:
+        """
+        Analyze an issue and stream the report.
+
+        Args:
+            title: Issue title
+            body: Issue body/description
+
+        Returns:
+            Iterator yielding chunks of the report or context info
+        """
+        issue_text = f"{title}\n\n{body}"
+        logger.info(f"Analyzing issue (streaming): {title}")
+
+        # 1. Search for relevant code (Functions, Classes, Files)
+        code_context = self.graph_service.hybrid_search(
+            query_text=issue_text,
+            node_type="Function",
+            limit=5
+        )
+
+        # Also search for Files directly
+        file_context = self.graph_service.hybrid_search(
+            query_text=issue_text,
+            node_type="File",
+            limit=3
+        )
+
+        # 2. Search for relevant history (Commits)
+        commit_context = self.graph_service.hybrid_search(
+            query_text=issue_text,
+            node_type="Commit",
+            limit=5
+        )
+
+        # Combine context
+        all_context = code_context + file_context + commit_context
+
+        if not all_context:
+            yield {"error": "No relevant context found in the knowledge graph."}
+            return
+
+        # Yield context info first
+        yield {
+            "context_used": [self._format_item_summary(item) for item in all_context]
+        }
+
+        # 3. Generate Report using LLM
+        system_prompt = """
+        You are an expert software engineer and debugger.
+        You are given a GitHub issue description and a set of relevant code snippets and commit history from the project's Knowledge Graph.
+
+        Your task is to analyze the issue and provide a detailed report containing:
+        1. **Root Cause Analysis**: What is likely causing the issue based on the code and history?
+        2. **Affected Areas**: Which files, classes, or functions are involved?
+        3. **Suggested Fix**: How can this be fixed? Provide code snippets if possible.
+        4. **Relevant History**: Are there recent commits that might have introduced this?
+
+        Be specific. Reference the filenames and function names provided in the context.
+        """
+
+        # Format context for LLM
+        context_str = self._format_context(all_context)
+
+        prompt = f"""
+        ISSUE TITLE: {title}
+
+        ISSUE BODY:
+        {body}
+
+        RELEVANT CONTEXT FROM KNOWLEDGE GRAPH:
+        {context_str}
+
+        Please provide your analysis report.
+        """
+
+        for chunk in self.llm_provider.stream_text(prompt=prompt, system_prompt=system_prompt):
+            yield {"chunk": chunk}
+
     def _format_context(self, items: List[Any]) -> str:
         output = []
         for item in items:
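The generator yields plain dicts — one context_used event, then a stream of chunk events (or a single error event) — which the route layer wraps as SSE. A consumer sketch (the IssueAnalyzer constructor arguments are hypothetical wiring; they are not shown in this diff):

analyzer = IssueAnalyzer(graph_service, llm_provider)  # hypothetical construction
report_parts = []
for event in analyzer.analyze_issue_stream("Login fails", "Users report 401s after upgrading."):
    if "error" in event:
        raise RuntimeError(event["error"])
    if "context_used" in event:
        print(f"Grounded on {len(event['context_used'])} graph items")
    elif "chunk" in event:
        report_parts.append(event["chunk"])
print("".join(report_parts))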
