Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ generate-plugin-header plugin:
rm -rf target/tmp



generate-selene-core-headers:
cbindgen \
--config selene-core/cbindgen.toml \
Expand Down
23 changes: 23 additions & 0 deletions selene-core/hatch_build.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,32 @@
import json
from hatchling.builders.hooks.plugin.interface import BuildHookInterface
from pathlib import Path


class SeleneCoreBuildHook(BuildHookInterface):
def initialize(self, version: str, build_data: dict) -> None:
# Generate and write the Trace JSON schema into _dist
# as selene_core isn't yet available we need to import selene_core/trace.py manually
import importlib.util

trace_module_path = Path("python/selene_core/trace.py")
spec = importlib.util.spec_from_file_location(
"selene_core.trace", trace_module_path
)
if spec is None or spec.loader is None:
raise RuntimeError(
f"Unable to load module spec for selene_core.trace from {trace_module_path}"
)
trace_module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(trace_module)
schema_path = Path(
"python/selene_core/_dist/share/selene-core/schemas/trace.json"
)
schema_path.parent.mkdir(parents=True, exist_ok=True)
schema_path.write_text(
json.dumps(trace_module.Trace.model_json_schema(), indent=2)
)

artifacts = []
dist_dir = Path("python/selene_core/_dist")
for artifact in dist_dir.rglob("*"):
Expand Down
4 changes: 3 additions & 1 deletion selene-core/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ dependencies = [
"llvmlite==0.45.1; sys_platform == 'darwin' and platform_machine == 'x86_64'",
"llvmlite~=0.47; sys_platform != 'darwin' or platform_machine != 'x86_64'",
"networkx>=2.6,<4",
"pydantic>=2.12.5",
"pydot>=4.0.0",
"pyyaml~=6.0",
"qir-qis>=0.1.2,<0.1.4",
Expand All @@ -25,7 +26,7 @@ homepage = "https://github.com/quantinuum/selene/selene-core"
repository = "https://github.com/quantinuum/selene/selene-core"

[build-system]
requires = ["hatchling"]
requires = ["hatchling", "pydantic>=2.12.5"]
build-backend = "hatchling.build"

[tool.hatch.build.targets.wheel]
Expand All @@ -40,6 +41,7 @@ cache-keys = [
{ file = "examples/**/*.rs" },
{ file = "examples/**/Cargo.lock" },
{ file = "examples/**/Cargo.toml" },
{ file = "python/selene_core/trace.py" },
{ file = "c/include/selene/*.h" },
{ file = "Cargo.toml" },
]
Expand Down
135 changes: 135 additions & 0 deletions selene-core/python/selene_core/trace.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
from typing import Annotated, Literal, Union, Callable
from pydantic import BaseModel, ConfigDict, Field


class PredicateResult(BaseModel):
predicate: str
result: bool


class UserProgramSource(BaseModel):
kind: Literal["UserProgram"] = "UserProgram"
index: int


class RuntimeSource(BaseModel):
kind: Literal["Runtime"] = "Runtime"
start_time: int
end_time: int


# in future we hope to add ErrorModelSource, so we can keep track of noise.
# this will require a breaking change, as the error model doesn't presently
# command the simulator via selene's internals (where logging can be done),
# but instead controls a simulator directly.


class AbstractEvent(BaseModel):
model_config = ConfigDict(
use_enum_values=True,
extra="ignore",
ser_json_bytes="base64",
val_json_bytes="base64",
)


class GateEvent(AbstractEvent):
kind: Literal["Gate"] = "Gate"
qubits: list[int] = Field(default_factory=list)
gate_name: str
params: list[float | int | bool] = Field(default_factory=list)
predicates: list[PredicateResult] = Field(default_factory=list)


class MeasurementEvent(AbstractEvent):
kind: Literal["Measurement"] = "Measurement"
qubit: int


class ResetEvent(AbstractEvent):
kind: Literal["Reset"] = "Reset"
qubit: int


class OpaquePayload(AbstractEvent):
kind: Literal["OpaquePayload"] = "OpaquePayload"
tag: int
data: bytes


class KeyValuePairPayload(AbstractEvent):
kind: Literal["KeyValuePairPayload"] = "KeyValuePairPayload"
data: dict[
str, str | int | float | bool | list[int] | list[float] | list[str] | list[bool]
]


CustomPayload = Annotated[
Union[OpaquePayload, KeyValuePairPayload],
Field(discriminator="kind"),
]


class CustomEvent(AbstractEvent):
kind: Literal["Custom"] = "Custom"
payload: CustomPayload


Event = Annotated[
Union[GateEvent, MeasurementEvent, ResetEvent, CustomEvent],
Field(discriminator="kind"),
]
Source = Annotated[
Union[UserProgramSource, RuntimeSource],
Field(discriminator="kind"),
]


class EventRecord(BaseModel):
source: Source
event: Event


class Trace(BaseModel):
events: list[EventRecord] = Field(default_factory=list)

def add_runtime_event(self, event: Event, start_time_ns: int, end_time_ns: int):
self.events.append(
EventRecord(
source=RuntimeSource(
start_time=start_time_ns,
end_time=end_time_ns,
),
event=event,
)
)

def add_user_program_event(self, event: Event, index: int):
self.events.append(
EventRecord(
source=UserProgramSource(index=index),
event=event,
)
)

def filter(self, predicate: Callable[[EventRecord], bool]) -> "Trace":
return Trace(events=list(filter(predicate, self.events)))

def strip_custom_events(self) -> "Trace":
return self.filter(lambda r: not isinstance(r.event, CustomEvent))

def strip_opaque_custom_events(self) -> "Trace":
return self.filter(
lambda r: (
not (
isinstance(r.event, CustomEvent)
and isinstance(r.event.payload, OpaquePayload)
)
)
)

def get_runtime_trace(self) -> "Trace":
return self.filter(lambda e: isinstance(e.source, RuntimeSource))

def get_user_program_trace(self) -> "Trace":
return self.filter(lambda e: isinstance(e.source, UserProgramSource))
Loading
Loading