Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions hw/ip/acc/util/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,17 @@ py_binary(
],
)

py_binary(
name = "check_clobbered_regs",
srcs = ["check_clobbered_regs.py"],
deps = [
"//hw/ip/acc/util/shared:control_flow",
"//hw/ip/acc/util/shared:decode",
"//hw/ip/acc/util/shared:information_flow_analysis",
requirement("pyelftools"),
],
)

py_binary(
name = "check_const_time",
srcs = ["check_const_time.py"],
Expand Down
116 changes: 116 additions & 0 deletions hw/ip/acc/util/check_clobbered_regs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#!/usr/bin/env python3
# Copyright zeroRISC Inc.
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0
'''Check that clobbered-register annotations match actual writes.

Runs the information-flow analysis on a single subroutine to generate the
correct clobbered-register annotation, then compares against the docstring
in the assembly source file.
'''

import argparse
import re
import sys
from typing import Optional, Tuple

from shared.control_flow import subroutine_control_graph
from shared.decode import decode_elf
from shared.information_flow_analysis import get_subroutine_iflow


def _find_docstring(source_path: str, symbol: str) -> Optional[str]:
'''Find the /** ... */ docstring for a symbol in a source file.'''
with open(source_path, 'r') as f:
content = f.read()

label_m = re.search(
r'^' + re.escape(symbol) + r'\s*:', content, re.MULTILINE)
if label_m is None:
return None

before = content[:label_m.start()]
doc_m = None
for m in re.finditer(r'/\*\*.*?\*/', before, re.DOTALL):
doc_m = m
if doc_m is None or 'clobbered registers:' not in doc_m.group(0):
return None

between = content[doc_m.end():label_m.start()]
if re.search(r'^[a-zA-Z_]\w*\s*:', between, re.MULTILINE):
return None

return doc_m.group(0)


def _extract_declared(doc: str) -> Tuple[str, str]:
'''Extract declared clobbered registers and flags from a docstring.'''
m = re.search(r'clobbered registers:\s*(.+)', doc)
regs = m.group(1).strip() if m else ''
if m:
# Some annotations span multiple lines. Collect continuation lines
# until we hit the next docstring field (clobbered flag groups,
# @param, called subroutines, etc.) or a non-register line.
for cont in re.finditer(r'\n\s*\*\s+(.+)', doc[m.end():]):
text = cont.group(1).strip()
if re.match(r'(clobbered|@param|flags|called|\*/)',
text, re.IGNORECASE):
break
if not re.match(r'[xw\d]', text):
break
regs += ', ' + text

m = re.search(r'clobbered flag groups:\s*(.+)', doc, re.IGNORECASE)
flags = m.group(1).strip() if m else ''
return regs, flags


def main() -> int:
parser = argparse.ArgumentParser(
description='Check clobbered-register annotations in ACC assembly.')
parser.add_argument('elf', help='Path to the ACC ELF binary.')
parser.add_argument('--subroutine', required=True,
help='Subroutine to check.')
parser.add_argument('--source', '-s', required=True,
help='Assembly source file (.s).')
args = parser.parse_args()

program = decode_elf(args.elf)

doc = _find_docstring(args.source, args.subroutine)
if doc is None:
print('No clobbered-registers docstring found for {}'.format(
args.subroutine), file=sys.stderr)
return 1

declared_regs, declared_flags = _extract_declared(doc)

# Generate actual clobbers via information-flow analysis.
graph = subroutine_control_graph(program, args.subroutine)
ret_iflow, _, _ = get_subroutine_iflow(
program, graph, args.subroutine, {})
if not ret_iflow.exists:
print('No return paths found for {}'.format(
args.subroutine), file=sys.stderr)
return 1

lines = ret_iflow.clobbered().split('\n')
gen_regs = lines[0].replace('* clobbered registers: ', '')
gen_flags = lines[1].replace('* clobbered flag groups: ', '')

if gen_regs != declared_regs or gen_flags.lower() != declared_flags.lower():
print('{}:'.format(args.subroutine))
if gen_regs != declared_regs:
print(' declared: clobbered registers: {}'.format(declared_regs))
print(' actual: clobbered registers: {}'.format(gen_regs))
if gen_flags.lower() != declared_flags.lower():
print(' declared: clobbered flag groups: {}'.format(
declared_flags))
print(' actual: clobbered flag groups: {}'.format(gen_flags))
return 1

return 0


if __name__ == '__main__':
sys.exit(main())
40 changes: 40 additions & 0 deletions rules/acc.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,46 @@ acc_consttime_test = rule(
},
)

def _acc_clobbered_regs_test_impl(ctx):
"""Check that a subroutine's clobbered-register annotation is correct."""
elf = [f for t in ctx.attr.deps for f in t[OutputGroupInfo].elf.to_list()]
if len(elf) != 1:
fail("Expected only one .elf file in dependencies, got: " + str(elf))
elf = elf[0]

if len(ctx.files.srcs) != 1:
fail("Expected exactly one source file, got: " + str(ctx.files.srcs))
src = ctx.files.srcs[0]
script_content = "{checker} {elf} --subroutine {sub} --source {src}".format(
checker = ctx.executable._checker.short_path,
elf = elf.short_path,
sub = ctx.attr.subroutine,
src = src.short_path,
)
ctx.actions.write(
output = ctx.outputs.executable,
content = script_content,
)

runfiles = ctx.runfiles(files = [elf] + ctx.files.srcs)
runfiles = runfiles.merge(ctx.attr._checker[DefaultInfo].default_runfiles)
return [DefaultInfo(runfiles = runfiles)]

acc_clobbered_regs_test = rule(
implementation = _acc_clobbered_regs_test_impl,
test = True,
attrs = {
"srcs": attr.label_list(allow_files = True),
"deps": attr.label_list(providers = [OutputGroupInfo]),
"subroutine": attr.string(mandatory = True),
"_checker": attr.label(
default = "//hw/ip/acc/util:check_clobbered_regs",
executable = True,
cfg = "exec",
),
},
)

acc_insn_count_range = rule(
implementation = _acc_insn_count_range,
attrs = {
Expand Down
139 changes: 138 additions & 1 deletion sw/acc/crypto/tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,147 @@
# Licensed under the Apache License, Version 2.0, see LICENSE for details.
# SPDX-License-Identifier: Apache-2.0

load("//rules:acc.bzl", "acc_consttime_test", "acc_library", "acc_sim_test", "acc_sim_test_suite")
load("//rules:acc.bzl", "acc_clobbered_regs_test", "acc_consttime_test", "acc_library", "acc_sim_test", "acc_sim_test_suite")

package(default_visibility = ["//visibility:public"])

# Clobbered-register checks: verify that docstring annotations match
# the actual registers written (determined by information-flow analysis).
# Functions with N-dependent register ranges (e.g. montmul) are excluded.
[acc_clobbered_regs_test(
name = elf + "_" + sub + "_clobbered",
srcs = ["//sw/acc/crypto:" + src],
subroutine = sub,
deps = ["//sw/acc/crypto:" + elf],
) for sub, src, elf in [
# Ed25519
("ed25519_verify_var", "ed25519.s", "run_ed25519"),
("ed25519_sign_prehashed", "ed25519.s", "run_ed25519"),
("sc_clamp", "ed25519.s", "run_ed25519"),
("affine_to_ext", "ed25519.s", "run_ed25519"),
("ext_to_affine", "ed25519.s", "run_ed25519"),
("affine_encode", "ed25519.s", "run_ed25519"),
("affine_decode_var", "ed25519.s", "run_ed25519"),
("ext_scmul", "ed25519.s", "run_ed25519"),
("ext_scmul_var", "ed25519.s", "run_ed25519"),
("ext_double", "ed25519.s", "run_ed25519"),
("ext_add", "ed25519.s", "run_ed25519"),
("ext_equal_var", "ed25519.s", "run_ed25519"),
("fe_pow_2252m3", "ed25519.s", "run_ed25519"),
("fe_init", "field25519.s", "run_ed25519"),
("fe_mul", "field25519.s", "run_ed25519"),
("fe_square", "field25519.s", "run_ed25519"),
("fe_inv", "field25519.s", "run_ed25519"),
("sc_init", "ed25519_scalar.s", "run_ed25519"),
("sc_reduce", "ed25519_scalar.s", "run_ed25519"),
("sc_mul", "ed25519_scalar.s", "run_ed25519"),
("sha512_compact", "sha512_compact.s", "run_ed25519"),
("sha512_init", "sha512_interface.s", "run_ed25519"),
("sha512_update", "sha512_interface.s", "run_ed25519"),
("sha512_final", "sha512_interface.s", "run_ed25519"),
("sha512_format_blocks", "sha512_interface.s", "run_ed25519"),
("copy", "sha512_interface.s", "run_ed25519"),
("reverse_bytes", "sha512_interface.s", "run_ed25519"),
("bswap64", "sha512_interface.s", "run_ed25519"),
("sha512_pad_message", "sha512_padding.s", "run_ed25519"),
("bswap32", "sha512_padding.s", "run_ed25519"),
# SHA-512 (standalone)
("sha512", "sha512.s", "run_sha512"),
# SHA-256
("sha256", "sha256.s", "run_sha256"),
("sha256_process_block", "sha256.s", "run_sha256"),
# P-256
("setup_modp", "p256_base.s", "run_p256"),
("mul_modp", "p256_base.s", "run_p256"),
("mod_inv", "p256_base.s", "run_p256"),
("fetch_proj_randomize", "p256_base.s", "run_p256"),
("proj_add", "p256_base.s", "run_p256"),
("proj_double", "p256_base.s", "run_p256"),
("proj_to_affine", "p256_base.s", "run_p256"),
("mod_mul_256x256", "p256_base.s", "run_p256"),
("mod_mul_320x128", "p256_base.s", "run_p256"),
("p256_reduce", "p256_base.s", "run_p256"),
("scalar_mult_int", "p256_base.s", "run_p256"),
("p256_base_mult", "p256_base.s", "run_p256"),
("p256_random_scalar", "p256_base.s", "run_p256"),
("p256_generate_random_key", "p256_base.s", "run_p256"),
("p256_generate_k", "p256_base.s", "run_p256"),
("boolean_to_arithmetic", "p256_base.s", "run_p256"),
("p256_key_from_seed", "p256_base.s", "run_p256"),
("p256_scalar_remask", "p256_base.s", "run_p256"),
("trigger_fault_if_fg0_not_z", "p256_base.s", "run_p256"),
("trigger_fault_if_fg0_z", "p256_base.s", "run_p256"),
("p256_isoncurve", "p256_isoncurve.s", "run_p256"),
("p256_check_public_key", "p256_isoncurve.s", "run_p256"),
("p256_isoncurve_proj", "p256_isoncurve_proj.s", "run_p256"),
("p256_shared_key", "p256_shared_key.s", "run_p256"),
("arithmetic_to_boolean", "p256_shared_key.s", "run_p256"),
("arithmetic_to_boolean_mod", "p256_shared_key.s", "run_p256"),
("p256_sign", "p256_sign.s", "run_p256"),
("p256_verify", "p256_verify.s", "run_p256"),
("mod_inv_var", "p256_verify.s", "run_p256"),
("copy_share", "run_p256.s", "run_p256"),
# P-384
("p384_mulmod_p", "p384_base.s", "run_p384"),
("proj_add_p384", "p384_base.s", "run_p384"),
("proj_to_affine_p384", "p384_base.s", "run_p384"),
("p384_scalar_remask", "p384_base.s", "run_p384"),
("p384_base_mult_checked", "p384_base_mult.s", "run_p384"),
("p384_base_mult", "p384_base_mult.s", "run_p384"),
("p384_isoncurve", "p384_isoncurve.s", "run_p384"),
("p384_isoncurve_check", "p384_isoncurve.s", "run_p384"),
("p384_check_public_key", "p384_isoncurve.s", "run_p384"),
("trigger_fault_if_fg0_not_z", "p384_isoncurve.s", "run_p384"),
("trigger_input_error_if_fg0_not_z", "p384_isoncurve.s", "run_p384"),
("p384_isoncurve_proj", "p384_isoncurve_proj.s", "run_p384"),
("p384_isoncurve_proj_check", "p384_isoncurve_proj.s", "run_p384"),
("scalar_mult_int_p384", "p384_internal_mult.s", "run_p384"),
("store_proj_randomize", "p384_internal_mult.s", "run_p384"),
("p384_verify", "p384_verify.s", "run_p384"),
("p384_sign", "p384_sign.s", "run_p384"),
("p384_generate_k", "p384_keygen.s", "run_p384"),
("p384_generate_random_key", "p384_keygen.s", "run_p384"),
("p384_random_scalar", "p384_keygen.s", "run_p384"),
("p384_key_from_seed", "p384_keygen_from_seed.s", "run_p384"),
("p384_scalar_mult", "p384_scalar_mult.s", "run_p384"),
("mod_inv_n_p384", "p384_modinv.s", "run_p384"),
("p384_arithmetic_to_boolean", "p384_a2b.s", "run_p384"),
("p384_arithmetic_to_boolean_mod", "p384_a2b.s", "run_p384"),
("p384_boolean_to_arithmetic", "p384_b2a.s", "run_p384"),
("copy_share", "run_p384.s", "run_p384"),
("keypair_from_seed", "run_p384.s", "run_p384"),
("keypair_random", "run_p384.s", "run_p384"),
("shared_key", "run_p384.s", "run_p384"),
("shared_key_from_seed", "run_p384.s", "run_p384"),
# X25519
("X25519", "x25519.s", "x25519_sideload"),
("scalar_mult", "x25519.s", "x25519_sideload"),
("ladderstep", "x25519.s", "x25519_sideload"),
# RSA (only non-N-dependent functions)
("m0inv", "montmul.s", "rsa"),
("mul256_w30xw25", "montmul.s", "rsa"),
("mul256_w30xw2", "montmul.s", "rsa"),
("mul256_w20xw21", "mul.s", "rsa"),
("zero_work_buf", "rsa.s", "rsa"),
("cp_work_buf", "rsa.s", "rsa"),
# RSA keygen (only non-N-dependent functions)
("prepare_pm1qm1", "rsa_keygen.s", "run_rsa_keygen"),
("is_zero_mod_small_prime", "rsa_keygen.s", "run_rsa_keygen"),
# Boot
("attestation_secret_key_from_seed", "boot.s", "boot"),
# N-dependent RSA/bignum functions are excluded:
# montmul, mont_loop, modexp, modexp_crt, modexp_65537, compute_rr,
# double_and_reduce, cond_sub_to_reg, cond_sub_to_dmem, cond_sub_shifted,
# sel_sqr_or_sqrmul, montmul_mul1, bignum_mul, bignum_mul256,
# bignum_rshift1, bignum_lshift256, div, mod, gcd, gcd_cond_*,
# lcm, modinv, primality/miller_rabin/test_witness/reduce_modw/is_mont*,
# rsa_keygen, rsa_check_*, derive_*, check_*, recover_*, generate_*,
# fold_bignum*, relprime_*, modinv_f4, rsa_key_from_cofactor
#
# sha3_shake functions are excluded (no acc_binary ELF available):
# sha3_init, sha3_update, sha3_final, shake_xof, shake_out, keccakf
]]

acc_sim_test(
name = "ed25519_encode_decode_test",
srcs = [
Expand Down
Loading