Skip to content

Commit 26ebaed

Browse files
Merge pull request #32 from kevinbackhouse/fmt
hatch fmt
2 parents efbde89 + 550d9b0 commit 26ebaed

14 files changed

Lines changed: 1061 additions & 658 deletions

src/seclab_taskflows/mcp_servers/alert_results_models.py

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,13 @@
55
from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped, relationship
66
from typing import Optional
77

8+
89
class Base(DeclarativeBase):
910
pass
1011

12+
1113
class AlertResults(Base):
12-
__tablename__ = 'alert_results'
14+
__tablename__ = "alert_results"
1315

1416
canonical_id: Mapped[int] = mapped_column(primary_key=True)
1517
alert_id: Mapped[str]
@@ -22,25 +24,29 @@ class AlertResults(Base):
2224
valid: Mapped[bool] = mapped_column(nullable=False, default=True)
2325
completed: Mapped[bool] = mapped_column(nullable=False, default=False)
2426

25-
relationship('AlertFlowGraph', cascade='all, delete')
27+
relationship("AlertFlowGraph", cascade="all, delete")
2628

2729
def __repr__(self):
28-
return (f"<AlertResults(alert_id={self.alert_id}, repo={self.repo}, "
29-
f"rule={self.rule}, language={self.language}, location={self.location}, "
30-
f"result={self.result}, created_at={self.created}, valid={self.valid}, completed={self.completed})>")
30+
return (
31+
f"<AlertResults(alert_id={self.alert_id}, repo={self.repo}, "
32+
f"rule={self.rule}, language={self.language}, location={self.location}, "
33+
f"result={self.result}, created_at={self.created}, valid={self.valid}, completed={self.completed})>"
34+
)
35+
3136

3237
class AlertFlowGraph(Base):
33-
__tablename__ = 'alert_flow_graph'
38+
__tablename__ = "alert_flow_graph"
3439

3540
id: Mapped[int] = mapped_column(primary_key=True)
36-
alert_canonical_id = Column(Integer, ForeignKey('alert_results.canonical_id', ondelete='CASCADE'))
41+
alert_canonical_id = Column(Integer, ForeignKey("alert_results.canonical_id", ondelete="CASCADE"))
3742
flow_data: Mapped[str] = mapped_column(Text)
3843
repo: Mapped[str]
3944
prev: Mapped[Optional[str]]
4045
next: Mapped[Optional[str]]
4146
started: Mapped[bool] = mapped_column(nullable=False, default=False)
4247

4348
def __repr__(self):
44-
return (f"<AlertFlowGraph(alert_canonical_id={self.alert_canonical_id}, "
45-
f"flow_data={self.flow_data}, repo={self.repo}, prev={self.prev}, next={self.next}, started={self.started})>")
46-
49+
return (
50+
f"<AlertFlowGraph(alert_canonical_id={self.alert_canonical_id}, "
51+
f"flow_data={self.flow_data}, repo={self.repo}, prev={self.prev}, next={self.next}, started={self.started})>"
52+
)

src/seclab_taskflows/mcp_servers/codeql_python/codeql_sqlite_models.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,12 +5,13 @@
55
from sqlalchemy.orm import DeclarativeBase, mapped_column, Mapped
66
from typing import Optional
77

8+
89
class Base(DeclarativeBase):
910
pass
1011

1112

1213
class Source(Base):
13-
__tablename__ = 'source'
14+
__tablename__ = "source"
1415

1516
id: Mapped[int] = mapped_column(primary_key=True)
1617
repo: Mapped[str]
@@ -20,6 +21,8 @@ class Source(Base):
2021
notes: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
2122

2223
def __repr__(self):
23-
return (f"<Source(id={self.id}, repo={self.repo}, "
24-
f"location={self.source_location}, line={self.line}, source_type={self.source_type}, "
25-
f"notes={self.notes})>")
24+
return (
25+
f"<Source(id={self.id}, repo={self.repo}, "
26+
f"location={self.source_location}, line={self.line}, source_type={self.source_type}, "
27+
f"notes={self.notes})>"
28+
)

src/seclab_taskflows/mcp_servers/codeql_python/mcp_server.py

Lines changed: 71 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@
66
from seclab_taskflow_agent.mcp_servers.codeql.client import run_query, _debug_log
77

88
from pydantic import Field
9-
#from mcp.server.fastmcp import FastMCP, Context
10-
from fastmcp import FastMCP # use FastMCP 2.0
9+
10+
# from mcp.server.fastmcp import FastMCP, Context
11+
from fastmcp import FastMCP # use FastMCP 2.0
1112
from pathlib import Path
1213
import os
1314
import csv
@@ -23,22 +24,20 @@
2324

2425
logging.basicConfig(
2526
level=logging.DEBUG,
26-
format='%(asctime)s - %(levelname)s - %(message)s',
27-
filename=log_file_name('mcp_codeql_python.log'),
28-
filemode='a'
27+
format="%(asctime)s - %(levelname)s - %(message)s",
28+
filename=log_file_name("mcp_codeql_python.log"),
29+
filemode="a",
2930
)
3031

31-
MEMORY = mcp_data_dir('seclab-taskflows', 'codeql', 'DATA_DIR')
32-
CODEQL_DBS_BASE_PATH = mcp_data_dir('seclab-taskflows', 'codeql', 'CODEQL_DBS_BASE_PATH')
32+
MEMORY = mcp_data_dir("seclab-taskflows", "codeql", "DATA_DIR")
33+
CODEQL_DBS_BASE_PATH = mcp_data_dir("seclab-taskflows", "codeql", "CODEQL_DBS_BASE_PATH")
3334

3435
mcp = FastMCP("CodeQL-Python")
3536

3637
# tool name -> templated query lookup for supported languages
3738
TEMPLATED_QUERY_PATHS = {
3839
# to add a language, port the templated query pack and add its definition here
39-
'python': {
40-
'remote_sources': 'queries/mcp-python/remote_sources.ql'
41-
}
40+
"python": {"remote_sources": "queries/mcp-python/remote_sources.ql"}
4241
}
4342

4443

@@ -49,9 +48,10 @@ def source_to_dict(result):
4948
"source_location": result.source_location,
5049
"line": result.line,
5150
"source_type": result.source_type,
52-
"notes": result.notes
51+
"notes": result.notes,
5352
}
5453

54+
5555
def _resolve_query_path(language: str, query: str) -> Path:
5656
global TEMPLATED_QUERY_PATHS
5757
if language not in TEMPLATED_QUERY_PATHS:
@@ -66,7 +66,7 @@ def _resolve_db_path(relative_db_path: str | Path):
6666
global CODEQL_DBS_BASE_PATH
6767
# joining onto a base path discards the base when the right-hand side is absolute ("/A" / "////B" gives "/B"), so leading slashes are stripped first
6868
# not windows compatible and probably needs additional hardening
69-
relative_db_path = str(relative_db_path).strip().lstrip('/')
69+
relative_db_path = str(relative_db_path).strip().lstrip("/")
7070
relative_db_path = Path(relative_db_path)
7171
absolute_path = (CODEQL_DBS_BASE_PATH / relative_db_path).resolve()
7272
if not absolute_path.is_relative_to(CODEQL_DBS_BASE_PATH.resolve()):
@@ -76,36 +76,38 @@ def _resolve_db_path(relative_db_path: str | Path):
7676
raise RuntimeError(f"Error: Database not found at {absolute_path}!")
7777
return str(absolute_path)
7878

79+
7980
# This sqlite database is specifically made for CodeQL for Python MCP.
8081
class CodeqlSqliteBackend:
8182
def __init__(self, memcache_state_dir: str):
8283
self.memcache_state_dir = memcache_state_dir
8384
if not Path(self.memcache_state_dir).exists():
84-
db_dir = 'sqlite://'
85+
db_dir = "sqlite://"
8586
else:
86-
db_dir = f'sqlite:///{self.memcache_state_dir}/codeql_sqlite.db'
87+
db_dir = f"sqlite:///{self.memcache_state_dir}/codeql_sqlite.db"
8788
self.engine = create_engine(db_dir, echo=False)
8889
Base.metadata.create_all(self.engine, tables=[Source.__table__])
8990

90-
91-
def store_new_source(self, repo, source_location, line, source_type, notes, update = False):
91+
def store_new_source(self, repo, source_location, line, source_type, notes, update=False):
9292
with Session(self.engine) as session:
93-
existing = session.query(Source).filter_by(repo = repo, source_location = source_location, line = line).first()
93+
existing = session.query(Source).filter_by(repo=repo, source_location=source_location, line=line).first()
9494
if existing:
9595
existing.notes = (existing.notes or "") + notes
9696
session.commit()
9797
return f"Updated notes for source at {source_location}, line {line} in {repo}."
9898
else:
9999
if update:
100100
return f"No source exists at repo {repo}, location {source_location}, line {line} to update."
101-
new_source = Source(repo = repo, source_location = source_location, line = line, source_type = source_type, notes = notes)
101+
new_source = Source(
102+
repo=repo, source_location=source_location, line=line, source_type=source_type, notes=notes
103+
)
102104
session.add(new_source)
103105
session.commit()
104106
return f"Added new source for {source_location} in {repo}."
105107

106108
def get_sources(self, repo):
107109
with Session(self.engine) as session:
108-
results = session.query(Source).filter_by(repo = repo).all()
110+
results = session.query(Source).filter_by(repo=repo).all()
109111
sources = [source_to_dict(source) for source in results]
110112
return sources
111113

@@ -119,8 +121,8 @@ def _csv_parse(raw):
119121
if i == 0:
120122
continue
121123
# col1 has what we care about, but offer flexibility
122-
keys = row[1].split(',')
123-
this_obj = {'description': row[0].format(*row[2:])}
124+
keys = row[1].split(",")
125+
this_obj = {"description": row[0].format(*row[2:])}
124126
for j, k in enumerate(keys):
125127
this_obj[k.strip()] = row[j + 2]
126128
results.append(this_obj)
@@ -141,27 +143,32 @@ def _run_query(query_name: str, database_path: str, language: str, template_valu
141143
except RuntimeError:
142144
return f"The query {query_name} is not supported for language: {language}"
143145
try:
144-
csv = run_query(Path(__file__).parent.resolve() /
145-
query_path,
146-
database_path,
147-
fmt='csv',
148-
template_values=template_values,
149-
log_stderr=True)
146+
csv = run_query(
147+
Path(__file__).parent.resolve() / query_path,
148+
database_path,
149+
fmt="csv",
150+
template_values=template_values,
151+
log_stderr=True,
152+
)
150153
return _csv_parse(csv)
151154
except Exception as e:
152155
return f"The query {query_name} encountered an error: {e}"
153156

157+
154158
backend = CodeqlSqliteBackend(MEMORY)
155159

160+
156161
@mcp.tool()
157-
def remote_sources(owner: str = Field(description="The owner of the GitHub repository"),
158-
repo: str = Field(description="The name of the GitHub repository"),
159-
database_path: str = Field(description="The CodeQL database path."),
160-
language: str = Field(description="The language used for the CodeQL database.")):
162+
def remote_sources(
163+
owner: str = Field(description="The owner of the GitHub repository"),
164+
repo: str = Field(description="The name of the GitHub repository"),
165+
database_path: str = Field(description="The CodeQL database path."),
166+
language: str = Field(description="The language used for the CodeQL database."),
167+
):
161168
"""List all remote sources and their locations in a CodeQL database, then store the results in a database."""
162169

163170
repo = process_repo(owner, repo)
164-
results = _run_query('remote_sources', database_path, language, {})
171+
results = _run_query("remote_sources", database_path, language, {})
165172

166173
# Check if results is an error message (a string) or valid data (a list of dicts)
167174
if isinstance(results, str):
@@ -172,53 +179,67 @@ def remote_sources(owner: str = Field(description="The owner of the GitHub repos
172179
for result in results:
173180
backend.store_new_source(
174181
repo=repo,
175-
source_location=result.get('location', ''),
176-
source_type=result.get('source', ''),
177-
line=int(result.get('line', '0')),
178-
notes=None, #result.get('description', ''),
179-
update=False
182+
source_location=result.get("location", ""),
183+
source_type=result.get("source", ""),
184+
line=int(result.get("line", "0")),
185+
notes=None, # result.get('description', ''),
186+
update=False,
180187
)
181188
stored_count += 1
182189

183190
return f"Stored {stored_count} remote sources in {repo}."
184191

192+
185193
@mcp.tool()
186-
def fetch_sources(owner: str = Field(description="The owner of the GitHub repository"),
187-
repo: str = Field(description="The name of the GitHub repository")):
194+
def fetch_sources(
195+
owner: str = Field(description="The owner of the GitHub repository"),
196+
repo: str = Field(description="The name of the GitHub repository"),
197+
):
188198
"""
189199
Fetch all sources from the repo
190200
"""
191201
repo = process_repo(owner, repo)
192202
return json.dumps(backend.get_sources(repo))
193203

204+
194205
@mcp.tool()
195-
def add_source_notes(owner: str = Field(description="The owner of the GitHub repository"),
196-
repo: str = Field(description="The name of the GitHub repository"),
197-
source_location: str = Field(description="The path to the file"),
198-
line: int = Field(description="The line number of the source"),
199-
notes: str = Field(description="The notes to append to this source")):
206+
def add_source_notes(
207+
owner: str = Field(description="The owner of the GitHub repository"),
208+
repo: str = Field(description="The name of the GitHub repository"),
209+
source_location: str = Field(description="The path to the file"),
210+
line: int = Field(description="The line number of the source"),
211+
notes: str = Field(description="The notes to append to this source"),
212+
):
200213
"""
201214
Add new notes to an existing source. The notes will be appended to any existing notes.
202215
"""
203216
repo = process_repo(owner, repo)
204-
return backend.store_new_source(repo = repo, source_location = source_location, line = line, source_type = "", notes = notes, update=True)
217+
return backend.store_new_source(
218+
repo=repo, source_location=source_location, line=line, source_type="", notes=notes, update=True
219+
)
220+
205221

206222
@mcp.tool()
207-
def clear_codeql_repo(owner: str = Field(description="The owner of the GitHub repository"),
208-
repo: str = Field(description="The name of the GitHub repository")):
223+
def clear_codeql_repo(
224+
owner: str = Field(description="The owner of the GitHub repository"),
225+
repo: str = Field(description="The name of the GitHub repository"),
226+
):
209227
"""
210228
Clear all data for a given repo from the database
211229
"""
212230
repo = process_repo(owner, repo)
213231
with Session(backend.engine) as session:
214-
deleted_sources = session.query(Source).filter_by(repo = repo).delete()
232+
deleted_sources = session.query(Source).filter_by(repo=repo).delete()
215233
session.commit()
216234
return f"Cleared {deleted_sources} sources from repo {repo}."
217235

236+
218237
if __name__ == "__main__":
219238
# Check if codeql/python-all pack is installed, if not install it
220-
if not os.path.isdir('/.codeql/packages/codeql/python-all'):
221-
pack_path = importlib.resources.files('seclab_taskflows.mcp_servers.codeql_python.queries').joinpath('mcp-python')
239+
if not os.path.isdir("/.codeql/packages/codeql/python-all"):
240+
pack_path = importlib.resources.files("seclab_taskflows.mcp_servers.codeql_python.queries").joinpath(
241+
"mcp-python"
242+
)
222243
print(f"Installing CodeQL pack from {pack_path}")
223244
subprocess.run(["codeql", "pack", "install", pack_path])
224245
mcp.run(show_banner=False, transport="http", host="127.0.0.1", port=9998)

0 commit comments

Comments
 (0)