From 7f30ca7b0a47c4af34592cd10fb6b94c0a4c2e91 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 29 Jul 2025 15:45:42 +0200 Subject: [PATCH 01/42] Merge branch 'feat-typing' into feat-numeric-types --- Cargo.toml | 1 + typing/Cargo.toml | 11 + typing/src/lib.rs | 705 +++++++++++++++++++++++++++++++++++++++++ typing/src/types.rs | 747 ++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 1464 insertions(+) create mode 100644 typing/Cargo.toml create mode 100644 typing/src/lib.rs create mode 100644 typing/src/types.rs diff --git a/Cargo.toml b/Cargo.toml index 5adb3e840..495ef6a7b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "air", "codegen/winterfell", "codegen/ace", + "typing", ] resolver = "2" diff --git a/typing/Cargo.toml b/typing/Cargo.toml new file mode 100644 index 000000000..442b5dd96 --- /dev/null +++ b/typing/Cargo.toml @@ -0,0 +1,11 @@ +[package] +name = "typing" +version = "0.1.0" +authors.workspace = true +license.workspace = true +repository.workspace = true +edition.workspace = true +rust-version.workspace = true + +[dev-dependencies] +pretty_assertions = "1.4.1" diff --git a/typing/src/lib.rs b/typing/src/lib.rs new file mode 100644 index 000000000..e4e5ba1fb --- /dev/null +++ b/typing/src/lib.rs @@ -0,0 +1,705 @@ +mod types; + +use std::fmt::Debug; + +pub use types::*; + +pub enum TypeError { + IncompatibleScalarTypes { + lhs: Option, + rhs: Option, + }, + IncompatibleShapes { + lhs: Option, + rhs: Option, + }, + IncompatibleType { + lhs: Option, + rhs: Option, + }, + TypeAlreadySet { + lhs: Option, + rhs: Option, + }, + NotASubtype { + lhs: Option, + rhs: Option, + }, + IncompatibleBinOp { + bin_ty: BinType, + }, +} + +pub trait Typing { + fn kind(&self) -> Option; + fn ty(&self) -> Option; + fn shape(&self) -> Option { + self.ty().and_then(|t| match t { + Type::Scalar(_) => ty!(_), + Type::Vector(_, len) => ty!(_[len]), + Type::Matrix(_, rows, cols) => ty!(_[rows, cols]), + }) + } + fn scalar_ty(&self) -> Option { + self.ty().scalar_ty() + } + fn ty_with_shape(&self, shape: impl Typing) -> Option { + let sty = self.scalar_ty(); + let shape = shape.shape(); + if sty.is_none() { + return shape; + } + match (self.scalar_ty().unwrap(), shape) { + (sty, None | Some(Type::Scalar(_))) => Some(Type::Scalar(Some(sty))), + (sty, Some(Type::Vector(_, len))) => Some(Type::Vector(Some(sty), len)), + (sty, Some(Type::Matrix(_, rows, cols))) => Some(Type::Matrix(Some(sty), rows, cols)), + } + } + fn is_scalar_felt(&self) -> bool { + matches!(self.scalar_ty(), sty!(felt)) + } + fn is_scalar_bool(&self) -> bool { + matches!(self.scalar_ty(), sty!(bool)) + } + fn is_scalar_int(&self) -> bool { + matches!(self.scalar_ty(), sty!(int)) + } + fn is_scalar(&self) -> bool { + matches!(self.ty(), Some(Type::Scalar(_))) + } + fn is_vector(&self) -> bool { + matches!(self.ty(), Some(Type::Vector(_, _))) + } + fn is_matrix(&self) -> bool { + matches!(self.ty(), Some(Type::Matrix(_, _, _))) + } + /// Returns true if the shape of `self` is a sub-shape of the shape of `other` + /// The shapes are compatible if: + /// - self is `?` (None) + /// - both are scalars + /// - both are vectors of the same length + /// - both are vectors with one of the lengths being `u32::MAX` + /// - both are matrices with the same number of rows and columns + /// - both are matrices with one or more of the rows or columns + /// being `u32::MAX`, the other pair (if any) being equal + /// + /// self\\other || _[r,c] | _[l] | _ | ? + /// ============||========|======|===|== + /// _[r,c] || y | n | n | n + /// _[l] || n | y | n | n + /// _ || n | n | y | n + /// ? || y | y | y | y + fn is_subshape(&self, other: &impl Typing) -> bool { + match (self.ty(), other.ty()) { + (None, _) => true, + (Some(Type::Scalar(_)), Some(Type::Scalar(_))) => true, + (Some(Type::Vector(_, len1)), Some(Type::Vector(_, len2))) => { + len1 == len2 || len1 == u32::MAX as usize || len2 == u32::MAX as usize + }, + (Some(Type::Matrix(_, rows1, cols1)), Some(Type::Matrix(_, rows2, cols2))) => { + (rows1 == rows2 || rows1 == u32::MAX as usize || rows2 == u32::MAX as usize) + && (cols1 == cols2 || cols1 == u32::MAX as usize || cols2 == u32::MAX as usize) + }, + _ => false, + } + } + + /// Returns true if the shape of `self` is compatible with the shape of `other` + /// The shapes are compatible if: + /// - either is `?` (None) + /// - both are scalars + /// - both are vectors of the same length + /// - both are vectors with one of the lengths being `u32::MAX` + /// - both are matrices with the same number of rows and columns + /// - both are matrices with one or more of the rows or columns + /// being `u32::MAX`, the other pair (if any) being equal + /// + /// self\\other || _[r,c] | _[l] | _ | ? + /// ============||========|======|===|== + /// _[r,c] || y | n | n | y + /// _[l] || n | y | n | y + /// _ || n | n | y | y + /// ? || y | y | y | y + /// + /// This is a more relaxed version of [Typing::is_subshape], + /// allowing for bi-directional compatibility checks. The only + /// difference is that it allows for `other` to be `?` (None). + fn is_shape_compatible(&self, other: &impl Typing) -> bool { + other.ty().is_none() || self.is_subshape(other) + } + /// Returns true if `self` is a subtype of `other` + /// Notation: + /// _ : ScalaType::Scalar(None) + /// Unknown scalar type + /// felt: ScalarType::Felt + /// Felt type + /// bool: ScalarType::Bool + /// Boolean type + /// int: ScalarType::Int + /// Integer type + /// + /// Subtyping rules: + /// - felt > bool > _ + /// - felt > int > _ + /// + /// Which means: + /// - `_` is a subtype of all scalar types + /// - `bool` is a subtype of `felt`: + /// a `bool` is a `felt with a `is_bool` property + /// - `int` is a subtype of `felt` + /// a `int` is a `felt` with the `constant` property + /// + /// self\\other || felt | bool | int | _ | + /// ============||======|======|=====|===| + /// felt || y | n | n | n | + /// bool || y | y | n | n | + /// int || y | n | y | n | + /// _ || y | y | y | y | + fn is_scalar_subtype(&self, other: &impl Typing) -> bool { + !matches!( + (self.scalar_ty(), other.scalar_ty()), + (sty!(felt), sty!(bool) | sty!(int) | sty!(_)) + | (sty!(bool), sty!(int) | sty!(_)) + | (sty!(int), sty!(bool) | sty!(_)) + ) + } + /// Returns true if `self` is a subtype of `other` + /// Notation: + /// ?: None + /// Unknown type + /// _: Type::Scalar(None) + /// Unknown scalar type + /// felt: Type::Scalar(Some(ScalarType::Felt)) + /// Felt type + /// bool: Type::Scalar(Some(ScalarType::Bool)) + /// Boolean type + /// int: Type::Scalar(Some(ScalarType::Int)) + /// Integer type + /// sty[len]: Type::Vector(Some(sty), len) + /// Vector of length `len` with scalar type `sty` + /// sty[rows, cols]: Type::Matrix(Some(sty), rows, cols) + /// Matrix with `rows` and `cols` with scalar type `sty` + /// + /// Subtyping rules: + /// ? > _ > felt > bool + /// ... > felt > int + /// ? > _[l] > felt[l] > bool[l] + /// ... > felt[l] > int[l] + /// ? > _[r, c] > felt[r, c] > bool[r, c] + /// ... > felt[r, c] > int[r, c] + /// Assuming shapes are compatible, this function checks if the scalar types, + /// with the added case of `?`, which all types are subtypes of. + /// See [Typing::is_scalar_subtype] for a more detailed explanation + /// of the subtyping rules of scalar types. + /// + /// self\\other || felt | bool | int | _ | ? | + /// ============||======|======|=====|===|===| + /// felt ||[ y | n | n | n]| n | + /// bool ||[ y | y | n | n]| n | + /// int ||[ y | n | y | n]| n | + /// _ ||[ y | y | y | y]| n | + /// ? || y | y | y | y | y | + /// + /// = self.is_scalar_subtype(other) | self == ? + /// [...] Denotes the result of the [Typing::is_scalar_subtype] method. + fn is_subtype(&self, other: &impl Typing) -> bool { + self.is_subshape(other) && self.is_scalar_subtype(other) + } + fn show_kind(&self) -> ShowOption { + ShowOption(self.kind()) + } + fn show_fn_ty(&self) -> ShowOption { + match self.kind() { + Some(Kind::Callable(fn_ty)) => ShowOption(Some(fn_ty)), + _ => ShowOption(None), + } + } + fn show_ty(&self) -> ShowOption { + ShowOption(self.ty()) + } + fn show_scalar_ty(&self) -> ShowOption { + ShowOption(self.scalar_ty()) + } + /// Returns the type of the current object, if it is known or can be inferred. + /// If the type is not known, it returns `None`. + /// If the type can be inferred, it returns the inferred type. + /// If the type cannot be inferred, it returns an appropriate error. + fn infer_ty(&self) -> Result, TypeError> { + Ok(self.ty()) + } +} + +pub trait ScalarTypeMut: Typing { + fn scalar_ty_mut(&mut self) -> &mut Option; + fn update_scalar_ty(&mut self, new_ty: Option) -> Result<(), TypeError> { + let ty = self.scalar_ty(); + if ty.is_none() { + // WARN: This should only be true before type inference + // Any None type should raise a diagnostic after type inference + *self.scalar_ty_mut() = new_ty; + } else if ty.is_scalar_subtype(&new_ty) { + // Allow widening of types + *self.scalar_ty_mut() = new_ty; + } else { + return Err(TypeError::IncompatibleScalarTypes { lhs: ty, rhs: new_ty }); + } + Ok(()) + } +} + +pub trait TypeMut: Typing + ScalarTypeMut { + fn ty_mut(&mut self) -> &mut Option; + fn update_ty(&mut self, new_ty: Option) -> Result<(), TypeError> { + let ty = self.ty(); + if ty.is_none() { + // WARN: This should only be true before type inference + // Any None type should raise a diagnostic after type inference + *self.ty_mut() = new_ty; + } else if ty.is_subtype(&new_ty) { + // Allow widening of types + *self.ty_mut() = new_ty; + } else { + return Err(TypeError::NotASubtype { lhs: ty, rhs: new_ty }); + } + Ok(()) + } +} + +#[derive(Clone, Debug, PartialEq, Eq)] +pub struct ShowOption(Option); + +impl core::fmt::Display for ShowOption { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match &self.0 { + None => f.write_str("!"), + Some(kind) => write!(f, "{kind}"), + } + } +} + +impl core::fmt::Display for ShowOption { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match &self.0 { + None => f.write_str("?"), + Some(fn_ty) => write!(f, "{fn_ty}"), + } + } +} + +impl core::fmt::Display for ShowOption { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match &self.0 { + None => f.write_str("?"), + Some(ty) => write!(f, "{ty}"), + } + } +} + +impl core::fmt::Display for ShowOption { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match &self.0 { + None => f.write_str("_"), + Some(sty) => write!(f, "{sty}"), + } + } +} + +impl Typing for ScalarType { + fn kind(&self) -> Option { + Some(Kind::Value(self.ty())) + } + fn ty(&self) -> Option { + Some(Type::Scalar(Some(*self))) + } + fn scalar_ty(&self) -> Option { + Some(*self) + } +} + +impl Typing for Type { + fn kind(&self) -> Option { + Some(Kind::Value(self.ty())) + } + fn ty(&self) -> Option { + Some(*self) + } + fn scalar_ty(&self) -> Option { + match self { + Type::Scalar(st) => *st, + Type::Vector(st, _) => *st, + Type::Matrix(st, _, _) => *st, + } + } +} + +impl ScalarTypeMut for Type { + fn scalar_ty_mut(&mut self) -> &mut Option { + match self { + Type::Scalar(st) => st, + Type::Vector(st, _) => st, + Type::Matrix(st, _, _) => st, + } + } +} + +impl Typing for FunctionType { + fn kind(&self) -> Option { + Some(Kind::Callable(self.clone())) + } + fn ty(&self) -> Option { + panic!("FunctionType does not have a concrete type") + } + fn scalar_ty(&self) -> Option { + panic!("FunctionType does not have a concrete scalar type") + } +} + +impl ScalarTypeMut for BinType { + fn scalar_ty_mut(&mut self) -> &mut Option { + self.ret_mut().scalar_ty_mut() + } +} + +impl TypeMut for BinType { + fn ty_mut(&mut self) -> &mut Option { + self.ret_mut() + } +} + +impl Typing for BinType { + fn kind(&self) -> Option { + self.as_fn().kind() + } + fn ty(&self) -> Option { + self.infer_ty().ok()? + } + fn infer_ty(&self) -> Result, TypeError> { + match self { + BinType::Eq(.., Some(ret)) + | BinType::Add(.., Some(ret)) + | BinType::Sub(.., Some(ret)) + | BinType::Mul(.., Some(ret)) + | BinType::Exp(.., Some(ret)) => Ok(Some(*ret)), + BinType::Eq(.., None) => self.infer_bin_ty_eq(), + BinType::Add(.., None) => self.infer_bin_ty_add(), + BinType::Sub(.., None) => self.infer_bin_ty_sub(), + BinType::Mul(.., None) => self.infer_bin_ty_mul(), + BinType::Exp(.., None) => self.infer_bin_ty_exp(), + } + } +} + +impl ScalarTypeMut for Kind { + fn scalar_ty_mut(&mut self) -> &mut Option { + match self { + Kind::Value(ty) => ty.scalar_ty_mut(), + Kind::Callable(_) => panic!("Cannot mutate scalar type of a callable kind"), + } + } +} + +impl TypeMut for Kind { + fn ty_mut(&mut self) -> &mut Option { + match self { + Kind::Value(ty) => ty, + Kind::Callable(_) => panic!("Cannot mutate type of a callable kind"), + } + } +} + +impl Typing for Kind { + fn kind(&self) -> Option { + Some(self.clone()) + } + fn ty(&self) -> Option { + let Kind::Value(ty) = self else { + return None; + }; + *ty + } +} + +impl ScalarTypeMut for Option { + fn scalar_ty_mut(&mut self) -> &mut Option { + self + } +} + +impl ScalarTypeMut for Option { + fn scalar_ty_mut(&mut self) -> &mut Option { + match self { + Some(Type::Scalar(st)) => st, + Some(Type::Vector(st, _)) => st, + Some(Type::Matrix(st, _, _)) => st, + None => panic!("Cannot mutate scalar type of None"), + } + } +} + +impl TypeMut for Option { + fn ty_mut(&mut self) -> &mut Option { + self + } +} + +impl Typing for Option { + fn kind(&self) -> Option { + self.as_ref().and_then(|t| t.kind()) + } + fn ty(&self) -> Option { + self.as_ref().and_then(|t| t.ty()) + } + fn scalar_ty(&self) -> Option { + self.as_ref().and_then(|t| t.scalar_ty()) + } +} + +#[macro_export] +macro_rules! assert_subtype { + ($a:expr; !$b:expr) => { + eprintln!("assert_subtype!({}; !{})", stringify!($a), stringify!($b)); + let res = !$crate::Typing::is_subtype(&$a, &$b); + assert!( + res, + "{}: !{}\nError: {} is a subtype of {}", + $crate::Typing::show_ty(&$a), + $crate::Typing::show_ty(&$b), + $crate::Typing::show_ty(&$a), + $crate::Typing::show_ty(&$b), + ); + }; + ($a:expr; $b:expr) => { + eprintln!("assert_subtype!({}; {})", stringify!($a), stringify!($b)); + let res = $crate::Typing::is_subtype(&$a, &$b); + assert!( + res, + "{}: {}\nError: {} is a not subtype of {}", + $crate::Typing::show_ty(&$a), + $crate::Typing::show_ty(&$b), + $crate::Typing::show_ty(&$a), + $crate::Typing::show_ty(&$b), + ); + }; +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::{sty, ty}; + use pretty_assertions::assert_eq; + + #[test] + fn test_typing() { + assert_eq!(ty!(?).ty(), None); + assert_eq!(ty!(?).scalar_ty(), sty!(_)); + assert_eq!(ty!(_).ty(), Some(Type::Scalar(sty!(_)))); + assert_eq!(ty!(_).scalar_ty(), sty!(_)); + assert_eq!(ty!(felt).ty(), Some(Type::Scalar(sty!(felt)))); + assert_eq!(ty!(felt).scalar_ty(), sty!(felt)); + assert_eq!(ty!(bool).ty(), Some(Type::Scalar(sty!(bool)))); + assert_eq!(ty!(bool).scalar_ty(), sty!(bool)); + assert_eq!(ty!(int).ty(), Some(Type::Scalar(sty!(int)))); + assert_eq!(ty!(int).scalar_ty(), sty!(int)); + assert_eq!(ty!(_[5]).ty(), Some(Type::Vector(sty!(_), 5))); + assert_eq!(ty!(_[5]).scalar_ty(), sty!(_)); + assert_eq!(ty!(felt[5]).ty(), Some(Type::Vector(sty!(felt), 5))); + assert_eq!(ty!(felt[5]).scalar_ty(), sty!(felt)); + assert_eq!(ty!(bool[5]).ty(), Some(Type::Vector(sty!(bool), 5))); + assert_eq!(ty!(bool[5]).scalar_ty(), sty!(bool)); + assert_eq!(ty!(int[5]).ty(), Some(Type::Vector(sty!(int), 5))); + assert_eq!(ty!(int[5]).scalar_ty(), sty!(int)); + assert_eq!(ty!(_[3, 4]).ty(), Some(Type::Matrix(sty!(_), 3, 4))); + assert_eq!(ty!(_[3, 4]).scalar_ty(), sty!(_)); + assert_eq!(ty!(felt[3, 4]).ty(), Some(Type::Matrix(sty!(felt), 3, 4))); + assert_eq!(ty!(felt[3, 4]).scalar_ty(), sty!(felt)); + assert_eq!(ty!(bool[3, 4]).ty(), Some(Type::Matrix(sty!(bool), 3, 4))); + assert_eq!(ty!(bool[3, 4]).scalar_ty(), sty!(bool)); + assert_eq!(ty!(int[3, 4]).ty(), Some(Type::Matrix(sty!(int), 3, 4))); + assert_eq!(ty!(int[3, 4]).scalar_ty(), sty!(int)); + } + + #[test] + fn test_typing_subtype() { + assert_subtype!(ty!(?); ty!(?)); + assert_subtype!(ty!(?); ty!(_)); + assert_subtype!(ty!(?); ty!(felt)); + assert_subtype!(ty!(?); ty!(bool)); + assert_subtype!(ty!(?); ty!(int)); + assert_subtype!(ty!(?); ty!(_[5])); + assert_subtype!(ty!(?); ty!(felt[5])); + assert_subtype!(ty!(?); ty!(bool[5])); + assert_subtype!(ty!(?); ty!(int[5])); + assert_subtype!(ty!(?); ty!(_[3, 4])); + assert_subtype!(ty!(?); ty!(felt[3, 4])); + assert_subtype!(ty!(?); ty!(bool[3, 4])); + assert_subtype!(ty!(?); ty!(int[3, 4])); + + assert_subtype!(ty!(_); !ty!(?)); + assert_subtype!(ty!(_); ty!(_)); + assert_subtype!(ty!(_); ty!(felt)); + assert_subtype!(ty!(_); ty!(bool)); + assert_subtype!(ty!(_); ty!(int)); + assert_subtype!(ty!(_); !ty!(_[5])); + assert_subtype!(ty!(_); !ty!(felt[5])); + assert_subtype!(ty!(_); !ty!(bool[5])); + assert_subtype!(ty!(_); !ty!(int[5])); + assert_subtype!(ty!(_); !ty!(_[3, 4])); + assert_subtype!(ty!(_); !ty!(felt[3, 4])); + assert_subtype!(ty!(_); !ty!(bool[3, 4])); + assert_subtype!(ty!(_); !ty!(int[3, 4])); + + assert_subtype!(ty!(felt); !ty!(?)); + assert_subtype!(ty!(felt); !ty!(_)); + assert_subtype!(ty!(felt); ty!(felt)); + assert_subtype!(ty!(felt); !ty!(bool)); + assert_subtype!(ty!(felt); !ty!(int)); + assert_subtype!(ty!(felt); !ty!(_[5])); + assert_subtype!(ty!(felt); !ty!(felt[5])); + assert_subtype!(ty!(felt); !ty!(bool[5])); + assert_subtype!(ty!(felt); !ty!(int[5])); + assert_subtype!(ty!(felt); !ty!(_[3, 4])); + assert_subtype!(ty!(felt); !ty!(felt[3, 4])); + assert_subtype!(ty!(felt); !ty!(bool[3, 4])); + assert_subtype!(ty!(felt); !ty!(int[3, 4])); + + assert_subtype!(ty!(bool); !ty!(?)); + assert_subtype!(ty!(bool); !ty!(_)); + assert_subtype!(ty!(bool); ty!(felt)); + assert_subtype!(ty!(bool); ty!(bool)); + assert_subtype!(ty!(bool); !ty!(int)); + assert_subtype!(ty!(bool); !ty!(_[5])); + assert_subtype!(ty!(bool); !ty!(felt[5])); + assert_subtype!(ty!(bool); !ty!(bool[5])); + assert_subtype!(ty!(bool); !ty!(int[5])); + assert_subtype!(ty!(bool); !ty!(_[3, 4])); + assert_subtype!(ty!(bool); !ty!(felt[3, 4])); + assert_subtype!(ty!(bool); !ty!(bool[3, 4])); + assert_subtype!(ty!(bool); !ty!(int[3, 4])); + + assert_subtype!(ty!(int); !ty!(?)); + assert_subtype!(ty!(int); !ty!(_)); + assert_subtype!(ty!(int); ty!(felt)); + assert_subtype!(ty!(int); !ty!(bool)); + assert_subtype!(ty!(int); ty!(int)); + assert_subtype!(ty!(int); !ty!(_[5])); + assert_subtype!(ty!(int); !ty!(felt[5])); + assert_subtype!(ty!(int); !ty!(bool[5])); + assert_subtype!(ty!(int); !ty!(int[5])); + assert_subtype!(ty!(int); !ty!(_[3, 4])); + assert_subtype!(ty!(int); !ty!(felt[3, 4])); + assert_subtype!(ty!(int); !ty!(bool[3, 4])); + assert_subtype!(ty!(int); !ty!(int[3, 4])); + + assert_subtype!(ty!(_[5]); !ty!(?)); + assert_subtype!(ty!(_[5]); !ty!(_)); + assert_subtype!(ty!(_[5]); !ty!(felt)); + assert_subtype!(ty!(_[5]); !ty!(bool)); + assert_subtype!(ty!(_[5]); !ty!(int)); + assert_subtype!(ty!(_[5]); ty!(_[5])); + assert_subtype!(ty!(_[5]); ty!(felt[5])); + assert_subtype!(ty!(_[5]); ty!(bool[5])); + assert_subtype!(ty!(_[5]); ty!(int[5])); + assert_subtype!(ty!(_[5]); !ty!(_[3, 4])); + assert_subtype!(ty!(_[5]); !ty!(felt[3, 4])); + assert_subtype!(ty!(_[5]); !ty!(bool[3, 4])); + assert_subtype!(ty!(_[5]); !ty!(int[3, 4])); + + assert_subtype!(ty!(felt[5]); !ty!(?)); + assert_subtype!(ty!(felt[5]); !ty!(_)); + assert_subtype!(ty!(felt[5]); !ty!(felt)); + assert_subtype!(ty!(felt[5]); !ty!(bool)); + assert_subtype!(ty!(felt[5]); !ty!(int)); + assert_subtype!(ty!(felt[5]); !ty!(_[5])); + assert_subtype!(ty!(felt[5]); ty!(felt[5])); + assert_subtype!(ty!(felt[5]); !ty!(bool[5])); + assert_subtype!(ty!(felt[5]); !ty!(int[5])); + assert_subtype!(ty!(felt[5]); !ty!(_[3, 4])); + assert_subtype!(ty!(felt[5]); !ty!(felt[3, 4])); + assert_subtype!(ty!(felt[5]); !ty!(bool[3, 4])); + assert_subtype!(ty!(felt[5]); !ty!(int[3, 4])); + + assert_subtype!(ty!(bool[5]); !ty!(?)); + assert_subtype!(ty!(bool[5]); !ty!(_)); + assert_subtype!(ty!(bool[5]); !ty!(felt)); + assert_subtype!(ty!(bool[5]); !ty!(bool)); + assert_subtype!(ty!(bool[5]); !ty!(int)); + assert_subtype!(ty!(bool[5]); !ty!(_[5])); + assert_subtype!(ty!(bool[5]); ty!(felt[5])); + assert_subtype!(ty!(bool[5]); ty!(bool[5])); + assert_subtype!(ty!(bool[5]); !ty!(int[5])); + assert_subtype!(ty!(bool[5]); !ty!(_[3, 4])); + assert_subtype!(ty!(bool[5]); !ty!(felt[3, 4])); + assert_subtype!(ty!(bool[5]); !ty!(bool[3, 4])); + assert_subtype!(ty!(bool[5]); !ty!(int[3, 4])); + + assert_subtype!(ty!(int[5]); !ty!(?)); + assert_subtype!(ty!(int[5]); !ty!(_)); + assert_subtype!(ty!(int[5]); !ty!(felt)); + assert_subtype!(ty!(int[5]); !ty!(bool)); + assert_subtype!(ty!(int[5]); !ty!(int)); + assert_subtype!(ty!(int[5]); !ty!(_[5])); + assert_subtype!(ty!(int[5]); ty!(felt[5])); + assert_subtype!(ty!(int[5]); !ty!(bool[5])); + assert_subtype!(ty!(int[5]); ty!(int[5])); + assert_subtype!(ty!(int[5]); !ty!(_[3, 4])); + assert_subtype!(ty!(int[5]); !ty!(felt[3, 4])); + assert_subtype!(ty!(int[5]); !ty!(bool[3, 4])); + assert_subtype!(ty!(int[5]); !ty!(int[3, 4])); + + assert_subtype!(ty!(_[3, 4]); !ty!(?)); + assert_subtype!(ty!(_[3, 4]); !ty!(_)); + assert_subtype!(ty!(_[3, 4]); !ty!(felt)); + assert_subtype!(ty!(_[3, 4]); !ty!(bool)); + assert_subtype!(ty!(_[3, 4]); !ty!(int)); + assert_subtype!(ty!(_[3, 4]); !ty!(_[5])); + assert_subtype!(ty!(_[3, 4]); !ty!(felt[5])); + assert_subtype!(ty!(_[3, 4]); !ty!(bool[5])); + assert_subtype!(ty!(_[3, 4]); !ty!(int[5])); + assert_subtype!(ty!(_[3, 4]); ty!(_[3, 4])); + assert_subtype!(ty!(_[3, 4]); ty!(felt[3, 4])); + assert_subtype!(ty!(_[3, 4]); ty!(bool[3, 4])); + assert_subtype!(ty!(_[3, 4]); ty!(int[3, 4])); + + assert_subtype!(ty!(felt[3, 4]); !ty!(?)); + assert_subtype!(ty!(felt[3, 4]); !ty!(_)); + assert_subtype!(ty!(felt[3, 4]); !ty!(felt)); + assert_subtype!(ty!(felt[3, 4]); !ty!(bool)); + assert_subtype!(ty!(felt[3, 4]); !ty!(int)); + assert_subtype!(ty!(felt[3, 4]); !ty!(_[5])); + assert_subtype!(ty!(felt[3, 4]); !ty!(felt[5])); + assert_subtype!(ty!(felt[3, 4]); !ty!(bool[5])); + assert_subtype!(ty!(felt[3, 4]); !ty!(int[5])); + assert_subtype!(ty!(felt[3, 4]); !ty!(_[3, 4])); + assert_subtype!(ty!(felt[3, 4]); ty!(felt[3, 4])); + assert_subtype!(ty!(felt[3, 4]); !ty!(bool[3, 4])); + assert_subtype!(ty!(felt[3, 4]); !ty!(int[3, 4])); + + assert_subtype!(ty!(bool[3, 4]); !ty!(?)); + assert_subtype!(ty!(bool[3, 4]); !ty!(_)); + assert_subtype!(ty!(bool[3, 4]); !ty!(felt)); + assert_subtype!(ty!(bool[3, 4]); !ty!(bool)); + assert_subtype!(ty!(bool[3, 4]); !ty!(int)); + assert_subtype!(ty!(bool[3, 4]); !ty!(_[5])); + assert_subtype!(ty!(bool[3, 4]); !ty!(felt[5])); + assert_subtype!(ty!(bool[3, 4]); !ty!(bool[5])); + assert_subtype!(ty!(bool[3, 4]); !ty!(int[5])); + assert_subtype!(ty!(bool[3, 4]); !ty!(_[3, 4])); + assert_subtype!(ty!(bool[3, 4]); ty!(felt[3, 4])); + assert_subtype!(ty!(bool[3, 4]); ty!(bool[3, 4])); + assert_subtype!(ty!(bool[3, 4]); !ty!(int[3, 4])); + + assert_subtype!(ty!(int[3, 4]); !ty!(?)); + assert_subtype!(ty!(int[3, 4]); !ty!(_)); + assert_subtype!(ty!(int[3, 4]); !ty!(felt)); + assert_subtype!(ty!(int[3, 4]); !ty!(bool)); + assert_subtype!(ty!(int[3, 4]); !ty!(int)); + assert_subtype!(ty!(int[3, 4]); !ty!(_[5])); + assert_subtype!(ty!(int[3, 4]); !ty!(felt[5])); + assert_subtype!(ty!(int[3, 4]); !ty!(bool[5])); + assert_subtype!(ty!(int[3, 4]); !ty!(int[5])); + assert_subtype!(ty!(int[3, 4]); !ty!(_[3, 4])); + assert_subtype!(ty!(int[3, 4]); ty!(felt[3, 4])); + assert_subtype!(ty!(int[3, 4]); !ty!(bool[3, 4])); + assert_subtype!(ty!(int[3, 4]); ty!(int[3, 4])); + } +} diff --git a/typing/src/types.rs b/typing/src/types.rs new file mode 100644 index 000000000..cdc32d1d0 --- /dev/null +++ b/typing/src/types.rs @@ -0,0 +1,747 @@ +use crate::{TypeError, Typing}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ScalarType { + Felt, + Bool, + Int, +} + +impl core::fmt::Display for ScalarType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Felt => f.write_str("felt"), + Self::Bool => f.write_str("bool"), + Self::Int => f.write_str("int"), + } + } +} + +#[macro_export] +macro_rules! sty { + // for pattern matching + // equivalent to a `_` in a match or let expression + (any) => { + _ + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any: $name:ident) => { + $name + }; + (_) => { + None + }; + (felt) => { + Some($crate::ScalarType::Felt) + }; + (bool) => { + Some($crate::ScalarType::Bool) + }; + (int) => { + Some($crate::ScalarType::Int) + }; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum Type { + // annotation: sty + // where sty is the scalar type + Scalar(Option), + // annotation: `sty[len]` + // where len is the number of elements in the vector, + // and sty is the scalar type + Vector(Option, usize), + // annotation: `sty[rows, cols]` + // where rows and cols are the dimensions of the matrix, + // and sty is the scalar type + Matrix(Option, usize, usize), +} + +impl core::fmt::Display for Type { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Scalar(None) => f.write_str("_"), + Self::Vector(None, len) => write!(f, "_[{len}]"), + Self::Matrix(None, rows, cols) => write!(f, "_[{rows}, {cols}]"), + Self::Scalar(Some(sty)) => f.write_str(&sty.to_string()), + Self::Vector(Some(sty), len) => write!(f, "{sty}[{len}]"), + Self::Matrix(Some(sty), rows, cols) => write!(f, "{sty}[{rows}, {cols}]"), + } + } +} + +#[macro_export] +macro_rules! ty { + // for pattern matching + // equivalent to a `_` in a match or let expression + (any) => { + _ + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any: $name:ident) => { + $name + }; + (?) => { + None::<$crate::Type> + }; + (_) => { + Some($crate::Type::Scalar(None)) + }; + ($sty:ident) => { + Some($crate::Type::Scalar($crate::sty!($sty))) + }; + (_[$len:expr]) => { + Some($crate::Type::Vector($crate::sty!(_), $len)) + }; + ($sty:ident[$len:expr]) => { + Some($crate::Type::Vector($crate::sty!($sty), $len)) + }; + (_[$rows:expr, $cols:expr]) => { + Some($crate::Type::Matrix($crate::sty!(_), $rows, $cols)) + }; + ($sty:ident[$rows:expr, $cols:expr]) => { + Some($crate::Type::Matrix($crate::sty!($sty), $rows, $cols)) + }; +} + +pub struct Push(Vec>); +impl Push { + pub fn push(mut self, ty: Option) -> Self { + self.0.push(ty); + self + } +} + +#[macro_export] +macro_rules! tys { + ([$($args:tt)+]) => { + tys!(RES: Push(vec![]); $($args)+).0 + }; + (RES: $res:expr; ) => { + $res + }; + (RES: $res:expr; ?) => { + tys!(RES: $crate::Push::push($res, $crate::ty!(?));) + }; + (RES: $res:expr; _$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { + tys!(RES: $crate::Push::push($res, $crate::ty!(_$([$($spec)+])?)); $($($rest)+)?) + }; + (RES: $res:expr; $name:ident$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { + tys!(RES: $crate::Push::push($res, $crate::ty!($name$([$($spec)+])?)); $($($rest)+)?) + }; +} + +#[macro_export] +macro_rules! tty { + ([$($n1:ident$([$l1:expr])?),*]) => { + Vec::>::from([ + $($crate::tty!($n1$([$l1])?)),* + ]) + }; + ($name:ident[$len:expr]) => { + $crate::ty!(felt[$len]) + }; + ($name:ident) => { + $crate::ty!(felt[1]) + }; +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum FunctionType { + Evaluator(Vec>), + Function(Vec>, Option), +} + +impl FunctionType { + pub fn args(&self) -> &[Option] { + match self { + Self::Evaluator(args) => args, + Self::Function(args, _) => args, + } + } + + pub fn ret(&self) -> Option { + match self { + Self::Evaluator(_) => None, + Self::Function(_, ret) => *ret, + } + } +} + +impl core::fmt::Display for FunctionType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Evaluator(args) => { + f.write_str("ev(")?; + write!( + f, + "[{}]", + args.iter().map(|ty| ty.show_ty().to_string()).collect::>().join(", ") + )?; + f.write_str(")") + }, + Self::Function(args, ret) => { + f.write_str("fn(")?; + f.write_str( + &args.iter().map(|ty| ty.show_ty().to_string()).collect::>().join(", "), + )?; + f.write_str(") -> ")?; + if let Some(ret_type) = ret { + write!(f, "{}", ret_type) + } else { + f.write_str("?") + } + }, + } + } +} + +#[macro_export] +macro_rules! fty { + (ev ([])) => { + $crate::FunctionType::Evaluator(vec![]) + }; + (ev ([$($tty:tt)+])) => { + $crate::FunctionType::Evaluator($crate::tty!([$($tty)+])) + }; + (fn ($($arg:tt)*) -> $($ret:tt)+) => { + $crate::FunctionType::Function(tys!([$($arg)*]), $crate::ty!($($ret)+)) + }; +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum BinType { + Eq(Option, Option, Option), + Add(Option, Option, Option), + Sub(Option, Option, Option), + Mul(Option, Option, Option), + Exp(Option, Option, Option), +} + +impl BinType { + pub fn lhs(&self) -> Option { + match self { + Self::Eq(lhs, _, _) + | Self::Add(lhs, _, _) + | Self::Sub(lhs, _, _) + | Self::Mul(lhs, _, _) + | Self::Exp(lhs, _, _) => *lhs, + } + } + + pub fn lhs_mut(&mut self) -> &mut Option { + match self { + Self::Eq(lhs, _, _) + | Self::Add(lhs, _, _) + | Self::Sub(lhs, _, _) + | Self::Mul(lhs, _, _) + | Self::Exp(lhs, _, _) => lhs, + } + } + + pub fn rhs(&self) -> Option { + match self { + Self::Eq(_, rhs, _) + | Self::Add(_, rhs, _) + | Self::Sub(_, rhs, _) + | Self::Mul(_, rhs, _) + | Self::Exp(_, rhs, _) => *rhs, + } + } + + pub fn rhs_mut(&mut self) -> &mut Option { + match self { + Self::Eq(_, rhs, _) + | Self::Add(_, rhs, _) + | Self::Sub(_, rhs, _) + | Self::Mul(_, rhs, _) + | Self::Exp(_, rhs, _) => rhs, + } + } + + pub fn ret(&self) -> Option { + match self { + Self::Eq(_, _, ret) + | Self::Add(_, _, ret) + | Self::Sub(_, _, ret) + | Self::Mul(_, _, ret) + | Self::Exp(_, _, ret) => *ret, + } + } + + pub fn ret_mut(&mut self) -> &mut Option { + match self { + Self::Eq(_, _, ret) + | Self::Add(_, _, ret) + | Self::Sub(_, _, ret) + | Self::Mul(_, _, ret) + | Self::Exp(_, _, ret) => ret, + } + } + + pub fn as_fn(&self) -> FunctionType { + match self { + Self::Eq(lhs, rhs, ret) + | Self::Add(lhs, rhs, ret) + | Self::Sub(lhs, rhs, ret) + | Self::Mul(lhs, rhs, ret) + | Self::Exp(lhs, rhs, ret) => FunctionType::Function(vec![*lhs, *rhs], *ret), + } + } + /// Returns a new [BinType] with all types casted to their [Type::Scalar] equivalent: + /// - `?` -> `_` + /// - `sty` -> `sty` + /// - `sty[len]` -> `sty` + /// - `sty[rows, cols]` -> `sty` + /// + /// This corresponds to the shape `_`. + pub fn without_shape(&self) -> Self { + match self { + Self::Eq(lhs, rhs, ret) => Self::Eq( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + Self::Add(lhs, rhs, ret) => Self::Add( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + Self::Sub(lhs, rhs, ret) => Self::Sub( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + Self::Mul(lhs, rhs, ret) => Self::Mul( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + Self::Exp(lhs, rhs, ret) => Self::Exp( + (*lhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*rhs).and_then(|ty| ty.ty_with_shape(ty!(_))), + (*ret).and_then(|ty| ty.ty_with_shape(ty!(_))), + ), + } + } +} + +impl core::fmt::Display for BinType { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Eq(lhs, rhs, None) => write!(f, "{} = {}", lhs.show_ty(), rhs.show_ty()), + Self::Add(lhs, rhs, None) => write!(f, "{} + {}", lhs.show_ty(), rhs.show_ty()), + Self::Sub(lhs, rhs, None) => write!(f, "{} - {}", lhs.show_ty(), rhs.show_ty()), + Self::Mul(lhs, rhs, None) => write!(f, "{} * {}", lhs.show_ty(), rhs.show_ty()), + Self::Exp(lhs, rhs, None) => write!(f, "{} ^ {}", lhs.show_ty(), rhs.show_ty()), + Self::Eq(lhs, rhs, ret) => { + write!(f, "{} = {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + Self::Add(lhs, rhs, ret) => { + write!(f, "{} + {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + Self::Sub(lhs, rhs, ret) => { + write!(f, "{} - {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + Self::Mul(lhs, rhs, ret) => { + write!(f, "{} * {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + Self::Exp(lhs, rhs, ret) => { + write!(f, "{} ^ {} -> {}", lhs.show_ty(), rhs.show_ty(), ret.show_ty()) + }, + } + } +} + +#[macro_export] +macro_rules! bty { + ($($bty:tt)+ -> $($ret:tt)+) => {{ + let b = $crate::bty!($($bty)+); + b.ret_mut().replace($crate::ty!($($ret)+)); + b + }}; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident = $($rhs:tt)+) => { + $crate::BinType::Eq($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? = $($rhs:tt)+) => { + $crate::BinType::Eq($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? = $($rhs:tt)+) => { + $crate::BinType::Eq($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? = $($rhs:tt)+) => { + $crate::BinType::Eq($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident + $($rhs:tt)+) => { + $crate::BinType::Add($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? + $($rhs:tt)+) => { + $crate::BinType::Add($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? + $($rhs:tt)+) => { + $crate::BinType::Add($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? + $($rhs:tt)+) => { + $crate::BinType::Add($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident - $($rhs:tt)+) => { + $crate::BinType::Sub($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? - $($rhs:tt)+) => { + $crate::BinType::Sub($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? - $($rhs:tt)+) => { + $crate::BinType::Sub($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? - $($rhs:tt)+) => { + $crate::BinType::Sub($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident * $($rhs:tt)+) => { + $crate::BinType::Mul($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? * $($rhs:tt)+) => { + $crate::BinType::Mul($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? * $($rhs:tt)+) => { + $crate::BinType::Mul($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? * $($rhs:tt)+) => { + $crate::BinType::Mul($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + // for pattern matching + // equivalent to a `$name` in a match or let expression + (any:$name:ident ^ $($rhs:tt)+) => { + $crate::BinType::Exp($crate::ty!(any:$name), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (? ^ $($rhs:tt)+) => { + $crate::BinType::Exp($crate::ty!(?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + (_$([$($spec:tt)+])? ^ $($rhs:tt)+) => { + $crate::BinType::Exp($crate::ty!(_$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; + ($sty:ident$([$($spec:tt)+])? ^ $($rhs:tt)+) => { + $crate::BinType::Exp($crate::ty!($sty$([$($spec)+])?), $crate::ty!($($rhs)+), $crate::ty!(?)) + }; +} + +impl BinType { + /// Returns the type of the result of an equality based on the types + /// of the left-hand side and right-hand side operands. + /// If the types are not compatible, it returns a [TypeError::IncompatibleBinOp]. + /// + /// Assuming shapes are compatible, the following table shows the result type + /// based on the scalar types of the operands: + /// ? == ? || felt | bool | int | _ | ? + /// =========||======|======|======|======|===== + /// felt || bool | bool | bool | bool | ? + /// bool || bool | bool | bool | bool | ? + /// int || bool | bool | bool | bool | ? + /// _ || bool | bool | bool | bool | ? + /// ? || ? | ? | ? | ? | ? + /// + /// So, the result type of an equality is: + /// - an error if lhs or rhs don't have a compatible shape, + /// - symmetric over the operands, + /// - any == ? -> ?, + /// - always `bool` otherwise + pub fn infer_bin_ty_eq(&self) -> Result, TypeError> { + if let Some(ret) = self.ret() { + return Ok(Some(ret)); + } + let lhs = self.lhs(); + let rhs = self.rhs(); + if lhs.is_none() || rhs.is_none() { + return Ok(ty!(?)); + } + if self.lhs().is_shape_compatible(&self.rhs()) { + Ok(ty!(bool)) + } else { + Err(TypeError::IncompatibleBinOp { bin_ty: *self }) + } + } + + /// Returns the type of the result of an addition based on the types + /// of the left-hand side and right-hand side operands. + /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleShapes]. + /// + /// based on the scalar types of the operands: + /// ? + ? || felt | bool | int | _ | ? + /// =========||======|======|======|======|===== + /// felt || felt | felt | felt | felt | felt + /// bool || felt | felt | felt | felt | felt + /// int || felt | felt | int | _ | ? + /// _ || felt | felt | _ | _ | ? + /// ? || felt | felt | ? | ? | ? + /// + /// So, the result type of an addition is: + /// - an error if lhs or rhs is not a scalar type or `?`, + /// - symmetric over the operands, + /// - felt + any -> felt + /// - bool + any -> felt + /// - ? + any -> ? + /// - int + int -> int + /// - everything else is an unknown scalar type `_` + pub fn infer_bin_ty_add(&self) -> Result, TypeError> { + if let Some(ret) = self.ret() { + return Ok(Some(ret)); + } + let lhs = self.lhs(); + let rhs = self.rhs(); + if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { + return Err(TypeError::IncompatibleShapes { lhs, rhs }); + } + match self { + bty!(felt + any) | bty!(any + felt) => Ok(ty!(felt)), + bty!(bool + any) | bty!(any + bool) => Ok(ty!(felt)), + bty!(? + any) | bty!(any + ?) => Ok(ty!(?)), + bty!(int + int) => Ok(ty!(int)), + _ => Ok(ty!(_)), + } + } + + /// Returns the type of the result of a substraction based on the types + /// of the left-hand side and right-hand side operands. + /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleShapes]. + /// + /// based on the scalar types of the operands: + /// ? - ? || felt | bool | int | _ | ? + /// =========||======|======|======|======|===== + /// felt || felt | felt | felt | felt | felt + /// bool || felt | felt | felt | felt | felt + /// int || felt | felt | int | _ | ? + /// _ || felt | felt | _ | _ | ? + /// ? || felt | felt | ? | ? | ? + /// + /// So, the result type of a substraction is: + /// - an error if either lhs or rhs is not a scalar type or `?`, + /// - symmetric over the operands, + /// - felt - any -> felt + /// - bool - any -> felt + /// - ? - any -> ? + /// - int - int -> int + /// - everything else is an unknown scalar type `_` + /// + /// This is the same as `infer_bin_ty_add`, so it reuses that method. + pub fn infer_bin_ty_sub(&self) -> Result, TypeError> { + self.infer_bin_ty_add() + } + + /// Returns the type of the result of a multiplication based on the types + /// of the left-hand side and right-hand side operands. + /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleShapes]. + /// + /// based on the scalar types of the operands: + /// ? * ? || felt | bool | int | _ | ? + /// =========||======|======|======|======|===== + /// felt || felt | felt | felt | felt | felt + /// bool || felt | bool | int | _ | ? + /// int || felt | int | int | _ | ? + /// _ || felt | _ | _ | _ | ? + /// ? || felt | ? | ? | ? | ? + /// + /// So, the result type of a multiplication is: + /// - an error if either lhs or rhs is not a scalar type or `?`, + /// - symmetric over the operands, + /// - felt * any -> felt + /// - ? * any -> ? + /// - _ * any -> _ + /// - int * int -> int + /// - bool * x -> x + /// - everything else is an unknown scalar type `_` + pub fn infer_bin_ty_mul(&self) -> Result, TypeError> { + if let Some(ret) = self.ret() { + return Ok(Some(ret)); + } + let lhs = self.lhs(); + let rhs = self.rhs(); + if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { + return Err(TypeError::IncompatibleShapes { lhs, rhs }); + } + match self { + bty!(felt * any) | bty!(any * felt) => Ok(ty!(felt)), + bty!(? * any) | bty!(any * ?) => Ok(ty!(?)), + bty!(_ * any) | bty!(any * _) => Ok(ty!(_)), + bty!(int * int) => Ok(ty!(int)), + bty!(bool * any:x) | bty!(any:x * bool) => Ok(*x), + _ => Ok(ty!(_)), + } + } + + /// Returns the type of the result of an exponentiation based on the types + /// of the left-hand side and right-hand side operands. + /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleBinOp]. + /// + /// based on the scalar types of the operands: + /// ? ^ ? || felt | bool | int | _ | ? + /// =========||======|======|======|======|===== + /// felt || err | err | felt | _ | ? + /// bool || err | err | bool | _ | ? + /// int || err | err | int | _ | ? + /// _ || err | err | _ | _ | ? + /// ? || err | err | ? | ? | ? + /// + /// + /// So, the result type of an exponentiation is: + /// - an error if either lhs or rhs is not a scalar type or `?`, + /// - an error if the rhs is not an int or `?`, + /// - any ^ ? -> ?, + /// - ? ^ any -> ?, + /// - any ^ _ -> _, + /// - any:x ^ int -> lhs, + /// + /// Because: + /// - it is an error if either lhs or rhs is not a scalar type or `?`, + /// - it is an error if rhs is not an int or `?`, + /// - a bool to any power is still a bool: + /// - 0^n = 0 + /// - 1^n = 1 + /// - a felt to any power is still a felt + /// - an int to any power is still an int + /// - a _ to any power is still a _ + /// - a ? to any power is still a ? + pub fn infer_bin_ty_exp(&self) -> Result, TypeError> { + if let Some(ret) = self.ret() { + return Ok(Some(ret)); + } + let lhs = self.lhs(); + let rhs = self.rhs(); + if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { + return Err(TypeError::IncompatibleBinOp { bin_ty: *self }); + } + match self { + bty!(any ^ felt) | bty!(any ^ bool) => { + Err(TypeError::IncompatibleBinOp { bin_ty: *self }) + }, + bty!(any ^ ?) | bty!(? ^ any) => Ok(ty!(?)), + bty!(any ^ _) => Ok(ty!(_)), + bty!(any:lhs ^ int) => Ok(*lhs), + _ => unreachable!("Undefined case for infer_bin_ty_exp: {self}"), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Kind { + Value(Option), + Callable(FunctionType), +} + +impl core::fmt::Display for Kind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + Self::Value(ty) => write!(f, "{}", ty.show_ty()), + Self::Callable(fty) => write!(f, "{}", fty.show_fn_ty()), + } + } +} + +#[macro_export] +macro_rules! kind { + (ev $($spec:tt)+) => { + $crate::Kind::Callable($crate::fty!(ev $($spec)+)) + }; + (fn ($($args:tt)*) -> $($ret:tt)+) => { + $crate::Kind::Callable($crate::fty!(fn ($($args)*) -> $($ret)+)) + }; + ($($spec:tt)+) => { + $crate::Kind::Value($crate::ty!($($spec)+)) + }; +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_macro_scalar_type() { + assert_eq!(sty!(_), None::); + assert_eq!(sty!(felt), Some(ScalarType::Felt)); + assert_eq!(sty!(bool), Some(ScalarType::Bool)); + assert_eq!(sty!(int), Some(ScalarType::Int)); + } + + #[test] + fn test_macro_type() { + assert_eq!(ty!(?), None::); + assert_eq!(ty!(_), Some(Type::Scalar(None))); + assert_eq!(ty!(felt), Some(Type::Scalar(Some(ScalarType::Felt)))); + assert_eq!(ty!(bool), Some(Type::Scalar(Some(ScalarType::Bool)))); + assert_eq!(ty!(int), Some(Type::Scalar(Some(ScalarType::Int)))); + assert_eq!(ty!(_[5]), Some(Type::Vector(None, 5))); + assert_eq!(ty!(int[5]), Some(Type::Vector(Some(ScalarType::Int), 5))); + assert_eq!(ty!(_[3, 4]), Some(Type::Matrix(None, 3, 4))); + assert_eq!(ty!(felt[3, 4]), Some(Type::Matrix(Some(ScalarType::Felt), 3, 4))); + } + + #[test] + fn test_macro_trace_segment_type() { + assert_eq!(tty!(a), ty!(felt[1])); + assert_eq!(tty!(a[5]), ty!(felt[5])); + assert_eq!(tty!([]), Vec::>::new()); + assert_eq!(tty!([a]), vec![ty!(felt[1])]); + assert_eq!(tty!([a[5]]), vec![ty!(felt[5])]); + assert_eq!(tty!([a[1], b[3]]), vec![ty!(felt[1]), ty!(felt[3])]); + } + + #[test] + fn test_macro_function_type() { + assert_eq!(fty!(ev([])), FunctionType::Evaluator(vec![])); + assert_eq!(fty!(ev([a])), FunctionType::Evaluator(vec![ty!(felt[1])])); + assert_eq!(fty!(ev([a[5]])), FunctionType::Evaluator(vec![ty!(felt[5])])); + assert_eq!(fty!(ev([a, b[3]])), FunctionType::Evaluator(vec![ty!(felt[1]), ty!(felt[3])])); + assert_eq!( + fty!(ev([a[1], b[3]])), + FunctionType::Evaluator(vec![ty!(felt[1]), ty!(felt[3])]) + ); + assert_eq!( + fty!(ev([a[1], b[3]])), + FunctionType::Evaluator(vec![ty!(felt[1]), ty!(felt[3])]) + ); + + assert_eq!(fty!(fn(int) -> felt), FunctionType::Function(vec![ty!(int)], ty!(felt))); + assert_eq!( + fty!(fn(int[5]) -> felt[3, 4]), + FunctionType::Function(vec![ty!(int[5])], ty!(felt[3, 4]),) + ); + assert_eq!( + fty!(fn(int[5], felt) -> felt[3, 4]), + FunctionType::Function(vec![ty!(int[5]), ty!(felt)], ty!(felt[3, 4]),) + ); + assert_eq!( + fty!(fn(int[5], felt, bool[3, 4]) -> felt[3, 4]), + FunctionType::Function(vec![ty!(int[5]), ty!(felt), ty!(bool[3, 4]),], ty!(felt[3, 4]),) + ); + } + + #[test] + fn test_macro_bin_type() { + assert_eq!(bty!(int + felt), BinType::Add(ty!(int), ty!(felt), ty!(?))); + assert_eq!(bty!(_ - felt), BinType::Sub(ty!(_), ty!(felt), ty!(?))); + assert_eq!(bty!(? = felt), BinType::Eq(ty!(?), ty!(felt), ty!(?))); + assert_eq!(bty!(int + ?), BinType::Add(ty!(int), ty!(?), ty!(?))); + assert_eq!(bty!(int - felt), BinType::Sub(ty!(int), ty!(felt), ty!(?))); + assert_eq!(bty!(int[2] * felt[2]), BinType::Mul(ty!(int[2]), ty!(felt[2]), ty!(?))); + assert_eq!(bty!(int[2, 3] ^ _), BinType::Exp(ty!(int[2, 3]), ty!(_), ty!(?))); + assert_eq!(bty!(bool[5] = _[5]), BinType::Eq(ty!(bool[5]), ty!(_[5]), ty!(?))); + } + + #[test] + fn test_macro_kind() { + assert_eq!(kind!(ev([])), Kind::Callable(fty!(ev([])))); + assert_eq!(kind!(ev([a])), Kind::Callable(fty!(ev([a])))); + assert_eq!(kind!(fn(int) -> felt), Kind::Callable(fty!(fn(int) -> felt))); + assert_eq!(kind!(int), Kind::Value(ty!(int))); + assert_eq!(kind!(_), Kind::Value(ty!(_))); + assert_eq!(kind!(bool[3, 4]), Kind::Value(ty!(bool[3, 4]))); + } +} From 474de8a0b81afa8cf2b4374d7215543702c99417 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 29 Jul 2025 16:53:31 +0200 Subject: [PATCH 02/42] docs(typing): add a NOTE for BinType::infer_bin_ty_sub, for issue #432 --- typing/src/types.rs | 15 +++++++++++++-- 1 file changed, 13 insertions(+), 2 deletions(-) diff --git a/typing/src/types.rs b/typing/src/types.rs index cdc32d1d0..82c21d6f5 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -527,11 +527,22 @@ impl BinType { /// - symmetric over the operands, /// - felt - any -> felt /// - bool - any -> felt - /// - ? - any -> ? /// - int - int -> int + /// - ? - any -> ? /// - everything else is an unknown scalar type `_` /// - /// This is the same as `infer_bin_ty_add`, so it reuses that method. + /// This is the same as [BinType::infer_bin_ty_add], so it reuses that method. + /// + /// NOTE: if we refine the types as described in #432, this method will need to be + /// updated to handle the substraction of `bool` and `int` types correctly. + /// This will no longer be symmetric over the operands! + /// Because: + /// - 0 - bool = - bool -> felt + /// - bool - 0 -> bool + /// - 0 - int = - int -> int (or error depending on the design) + /// - int - 0 -> int + /// - 1 - bool -> bool + /// - bool - 1 -> felt pub fn infer_bin_ty_sub(&self) -> Result, TypeError> { self.infer_bin_ty_add() } From 1255bc9d764f43097639899b8910a5d7422bfbbb Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 29 Jul 2025 17:13:51 +0200 Subject: [PATCH 03/42] chores(typing): make format --- typing/src/lib.rs | 23 +++++++++++------------ typing/src/types.rs | 20 ++++++++++---------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index e4e5ba1fb..a721bba22 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -80,8 +80,8 @@ pub trait Typing { /// - both are vectors of the same length /// - both are vectors with one of the lengths being `u32::MAX` /// - both are matrices with the same number of rows and columns - /// - both are matrices with one or more of the rows or columns - /// being `u32::MAX`, the other pair (if any) being equal + /// - both are matrices with one or more of the rows or columns being `u32::MAX`, the other pair + /// (if any) being equal /// /// self\\other || _[r,c] | _[l] | _ | ? /// ============||========|======|===|== @@ -111,8 +111,8 @@ pub trait Typing { /// - both are vectors of the same length /// - both are vectors with one of the lengths being `u32::MAX` /// - both are matrices with the same number of rows and columns - /// - both are matrices with one or more of the rows or columns - /// being `u32::MAX`, the other pair (if any) being equal + /// - both are matrices with one or more of the rows or columns being `u32::MAX`, the other pair + /// (if any) being equal /// /// self\\other || _[r,c] | _[l] | _ | ? /// ============||========|======|===|== @@ -144,10 +144,8 @@ pub trait Typing { /// /// Which means: /// - `_` is a subtype of all scalar types - /// - `bool` is a subtype of `felt`: - /// a `bool` is a `felt with a `is_bool` property - /// - `int` is a subtype of `felt` - /// a `int` is a `felt` with the `constant` property + /// - `bool` is a subtype of `felt`: a `bool` is a `felt with a `is_bool` property + /// - `int` is a subtype of `felt`: a `int` is a `felt` with the `constant` property /// /// self\\other || felt | bool | int | _ | /// ============||======|======|=====|===| @@ -327,7 +325,7 @@ impl Typing for Type { match self { Type::Scalar(st) => *st, Type::Vector(st, _) => *st, - Type::Matrix(st, _, _) => *st, + Type::Matrix(st, ..) => *st, } } } @@ -337,7 +335,7 @@ impl ScalarTypeMut for Type { match self { Type::Scalar(st) => st, Type::Vector(st, _) => st, - Type::Matrix(st, _, _) => st, + Type::Matrix(st, ..) => st, } } } @@ -430,7 +428,7 @@ impl ScalarTypeMut for Option { match self { Some(Type::Scalar(st)) => st, Some(Type::Vector(st, _)) => st, - Some(Type::Matrix(st, _, _)) => st, + Some(Type::Matrix(st, ..)) => st, None => panic!("Cannot mutate scalar type of None"), } } @@ -484,9 +482,10 @@ macro_rules! assert_subtype { #[cfg(test)] mod tests { + use pretty_assertions::assert_eq; + use super::*; use crate::{sty, ty}; - use pretty_assertions::assert_eq; #[test] fn test_typing() { diff --git a/typing/src/types.rs b/typing/src/types.rs index 82c21d6f5..c7e1ea184 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -223,21 +223,21 @@ pub enum BinType { impl BinType { pub fn lhs(&self) -> Option { match self { - Self::Eq(lhs, _, _) - | Self::Add(lhs, _, _) - | Self::Sub(lhs, _, _) - | Self::Mul(lhs, _, _) - | Self::Exp(lhs, _, _) => *lhs, + Self::Eq(lhs, ..) + | Self::Add(lhs, ..) + | Self::Sub(lhs, ..) + | Self::Mul(lhs, ..) + | Self::Exp(lhs, ..) => *lhs, } } pub fn lhs_mut(&mut self) -> &mut Option { match self { - Self::Eq(lhs, _, _) - | Self::Add(lhs, _, _) - | Self::Sub(lhs, _, _) - | Self::Mul(lhs, _, _) - | Self::Exp(lhs, _, _) => lhs, + Self::Eq(lhs, ..) + | Self::Add(lhs, ..) + | Self::Sub(lhs, ..) + | Self::Mul(lhs, ..) + | Self::Exp(lhs, ..) => lhs, } } From 9f17e2496e7be680a8fc74132888dc9b5294d2fb Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 30 Jul 2025 12:03:50 +0200 Subject: [PATCH 04/42] fix(typing): adapt to old parser/ast/types api --- typing/src/lib.rs | 15 +++++++++++++-- typing/src/types.rs | 31 ++++++++++++++++++------------- 2 files changed, 31 insertions(+), 15 deletions(-) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index a721bba22..ea1c9d2b4 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -73,6 +73,17 @@ pub trait Typing { fn is_matrix(&self) -> bool { matches!(self.ty(), Some(Type::Matrix(_, _, _))) } + /// Returns true if this type is an aggregate + #[inline] + fn is_aggregate(&self) -> bool { + self.is_vector() || self.is_matrix() + } + + /// Returns true if this type is a valid iterable in a comprehension + #[inline] + fn is_iterable(&self) -> bool { + self.is_vector() + } /// Returns true if the shape of `self` is a sub-shape of the shape of `other` /// The shapes are compatible if: /// - self is `?` (None) @@ -354,13 +365,13 @@ impl Typing for FunctionType { impl ScalarTypeMut for BinType { fn scalar_ty_mut(&mut self) -> &mut Option { - self.ret_mut().scalar_ty_mut() + self.result_mut().scalar_ty_mut() } } impl TypeMut for BinType { fn ty_mut(&mut self) -> &mut Option { - self.ret_mut() + self.result_mut() } } diff --git a/typing/src/types.rs b/typing/src/types.rs index c7e1ea184..6d2819d44 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -1,6 +1,6 @@ use crate::{TypeError, Typing}; -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq)] pub enum ScalarType { Felt, Bool, @@ -43,7 +43,8 @@ macro_rules! sty { }; } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +/// The types of values which can be represented in an AirScript program +#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq)] pub enum Type { // annotation: sty // where sty is the scalar type @@ -148,9 +149,13 @@ macro_rules! tty { }; } -#[derive(Debug, Clone, PartialEq, Eq)] +/// Represents the type signature of a function +#[derive(Hash, Debug, Clone, PartialEq, Eq)] pub enum FunctionType { + /// An evaluator function, which has no results, and has + /// a complex type signature due to the nature of trace bindings Evaluator(Vec>), + /// A standard function with one or more inputs, and a result Function(Vec>, Option), } @@ -162,7 +167,7 @@ impl FunctionType { } } - pub fn ret(&self) -> Option { + pub fn result(&self) -> Option { match self { Self::Evaluator(_) => None, Self::Function(_, ret) => *ret, @@ -211,7 +216,7 @@ macro_rules! fty { }; } -#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq)] pub enum BinType { Eq(Option, Option, Option), Add(Option, Option, Option), @@ -261,7 +266,7 @@ impl BinType { } } - pub fn ret(&self) -> Option { + pub fn result(&self) -> Option { match self { Self::Eq(_, _, ret) | Self::Add(_, _, ret) @@ -271,7 +276,7 @@ impl BinType { } } - pub fn ret_mut(&mut self) -> &mut Option { + pub fn result_mut(&mut self) -> &mut Option { match self { Self::Eq(_, _, ret) | Self::Add(_, _, ret) @@ -359,7 +364,7 @@ impl core::fmt::Display for BinType { macro_rules! bty { ($($bty:tt)+ -> $($ret:tt)+) => {{ let b = $crate::bty!($($bty)+); - b.ret_mut().replace($crate::ty!($($ret)+)); + b.result_mut().replace($crate::ty!($($ret)+)); b }}; // for pattern matching @@ -455,7 +460,7 @@ impl BinType { /// - any == ? -> ?, /// - always `bool` otherwise pub fn infer_bin_ty_eq(&self) -> Result, TypeError> { - if let Some(ret) = self.ret() { + if let Some(ret) = self.result() { return Ok(Some(ret)); } let lhs = self.lhs(); @@ -492,7 +497,7 @@ impl BinType { /// - int + int -> int /// - everything else is an unknown scalar type `_` pub fn infer_bin_ty_add(&self) -> Result, TypeError> { - if let Some(ret) = self.ret() { + if let Some(ret) = self.result() { return Ok(Some(ret)); } let lhs = self.lhs(); @@ -570,7 +575,7 @@ impl BinType { /// - bool * x -> x /// - everything else is an unknown scalar type `_` pub fn infer_bin_ty_mul(&self) -> Result, TypeError> { - if let Some(ret) = self.ret() { + if let Some(ret) = self.result() { return Ok(Some(ret)); } let lhs = self.lhs(); @@ -621,7 +626,7 @@ impl BinType { /// - a _ to any power is still a _ /// - a ? to any power is still a ? pub fn infer_bin_ty_exp(&self) -> Result, TypeError> { - if let Some(ret) = self.ret() { + if let Some(ret) = self.result() { return Ok(Some(ret)); } let lhs = self.lhs(); @@ -641,7 +646,7 @@ impl BinType { } } -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Hash, Debug, Clone, PartialEq, Eq)] pub enum Kind { Value(Option), Callable(FunctionType), From 926d666e562ee08687f5ec94141ecdf35ae3d53c Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 30 Jul 2025 16:03:11 +0200 Subject: [PATCH 05/42] feat(typing): impl Typing for Span --- typing/Cargo.toml | 3 +++ typing/src/lib.rs | 10 ++++++++++ 2 files changed, 13 insertions(+) diff --git a/typing/Cargo.toml b/typing/Cargo.toml index 442b5dd96..3095ca63a 100644 --- a/typing/Cargo.toml +++ b/typing/Cargo.toml @@ -7,5 +7,8 @@ repository.workspace = true edition.workspace = true rust-version.workspace = true +[dependencies] +miden-diagnostics = { workspace = true } + [dev-dependencies] pretty_assertions = "1.4.1" diff --git a/typing/src/lib.rs b/typing/src/lib.rs index ea1c9d2b4..15f16add7 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -2,6 +2,7 @@ mod types; use std::fmt::Debug; +use miden_diagnostics::Span; pub use types::*; pub enum TypeError { @@ -463,6 +464,15 @@ impl Typing for Option { } } +impl Typing for Span { + fn kind(&self) -> Option { + self.item.kind() + } + fn ty(&self) -> Option { + self.item.ty() + } +} + #[macro_export] macro_rules! assert_subtype { ($a:expr; !$b:expr) => { From 74af2b223d71b29c4a14d3f87aacf145253065a0 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 30 Jul 2025 16:03:56 +0200 Subject: [PATCH 06/42] feat(typing): allow forwarding of idents in sty! macro --- typing/src/types.rs | 3 +++ 1 file changed, 3 insertions(+) diff --git a/typing/src/types.rs b/typing/src/types.rs index 6d2819d44..4563630e5 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -41,6 +41,9 @@ macro_rules! sty { (int) => { Some($crate::ScalarType::Int) }; + ($sty:ident) => { + $sty + }; } /// The types of values which can be represented in an AirScript program From 5a8af93d05095c028d0346432d6912e68b84c213 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 30 Jul 2025 16:04:54 +0200 Subject: [PATCH 07/42] refactor(typing): default impl for Kind for Kind::Value Types --- typing/src/lib.rs | 13 +++---------- 1 file changed, 3 insertions(+), 10 deletions(-) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index 15f16add7..2ed15469d 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -32,7 +32,9 @@ pub enum TypeError { } pub trait Typing { - fn kind(&self) -> Option; + fn kind(&self) -> Option { + Some(Kind::Value(self.ty())) + } fn ty(&self) -> Option; fn shape(&self) -> Option { self.ty().and_then(|t| match t { @@ -315,9 +317,6 @@ impl core::fmt::Display for ShowOption { } impl Typing for ScalarType { - fn kind(&self) -> Option { - Some(Kind::Value(self.ty())) - } fn ty(&self) -> Option { Some(Type::Scalar(Some(*self))) } @@ -327,9 +326,6 @@ impl Typing for ScalarType { } impl Typing for Type { - fn kind(&self) -> Option { - Some(Kind::Value(self.ty())) - } fn ty(&self) -> Option { Some(*self) } @@ -377,9 +373,6 @@ impl TypeMut for BinType { } impl Typing for BinType { - fn kind(&self) -> Option { - self.as_fn().kind() - } fn ty(&self) -> Option { self.infer_ty().ok()? } From bdec887e95f8efb93b3cbe29f3cc8356f077e80f Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 30 Jul 2025 16:33:00 +0200 Subject: [PATCH 08/42] feat(typing): impl Typing for Vec --- typing/src/lib.rs | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index 2ed15469d..63aa1b44e 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -466,6 +466,22 @@ impl Typing for Span { } } +impl Typing for Vec { + fn kind(&self) -> Option { + match self.first().map(|t| t.kind())?? { + Kind::Value(ty) => ty.map(|t| Kind::Value(Some(t))), + Kind::Callable(_) => unimplemented!("A vector of callables is not supported"), + } + } + fn ty(&self) -> Option { + match self.first().map(|t| t.ty())?? { + Type::Scalar(st) => ty!(st[self.len()]), + Type::Vector(st, cols) => ty!(st[self.len(), cols]), + Type::Matrix(..) => unimplemented!("A vector of matrices is not supported"), + } + } +} + #[macro_export] macro_rules! assert_subtype { ($a:expr; !$b:expr) => { From cc98758e1574a3383cf1a70a9142a10aead0d571 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 31 Jul 2025 09:33:05 +0200 Subject: [PATCH 09/42] feat(typing): Aggregate Kind variant + rework Show + impl Typing for Box + Vec --- typing/src/lib.rs | 196 ++++++++++++++++++++++++++++++++++++++------ typing/src/types.rs | 45 ++++++++-- 2 files changed, 209 insertions(+), 32 deletions(-) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index 63aa1b44e..eea3aceb3 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -217,20 +217,20 @@ pub trait Typing { fn is_subtype(&self, other: &impl Typing) -> bool { self.is_subshape(other) && self.is_scalar_subtype(other) } - fn show_kind(&self) -> ShowOption { - ShowOption(self.kind()) + fn show_kind(&self) -> Show> { + Show(self.kind()) } - fn show_fn_ty(&self) -> ShowOption { + fn show_fn_ty(&self) -> Show> { match self.kind() { - Some(Kind::Callable(fn_ty)) => ShowOption(Some(fn_ty)), - _ => ShowOption(None), + Some(Kind::Callable(fn_ty)) => Show(Some(fn_ty)), + _ => Show(None), } } - fn show_ty(&self) -> ShowOption { - ShowOption(self.ty()) + fn show_ty(&self) -> Show> { + Show(self.ty()) } - fn show_scalar_ty(&self) -> ShowOption { - ShowOption(self.scalar_ty()) + fn show_scalar_ty(&self) -> Show> { + Show(self.scalar_ty()) } /// Returns the type of the current object, if it is known or can be inferred. /// If the type is not known, it returns `None`. @@ -239,6 +239,38 @@ pub trait Typing { fn infer_ty(&self) -> Result, TypeError> { Ok(self.ty()) } + fn lowest_common_supertype(&self, other: &impl Typing) -> Option { + match (self.ty(), other.ty()) { + (ty!(?), _) | (_, ty!(?)) => ty!(?), + (ty!(_), Some(Type::Scalar(_))) | (Some(Type::Scalar(_)), ty!(_)) => ty!(_), + (Some(Type::Vector(sty!(_), llen)), Some(Type::Vector(_, rlen))) + | (Some(Type::Vector(_, llen)), Some(Type::Vector(sty!(_), rlen))) => { + ty!(_[llen.max(rlen)]) + }, + (Some(Type::Matrix(sty!(_), lrows, lcols)), Some(Type::Matrix(_, rrows, rcols))) + | (Some(Type::Matrix(_, lrows, lcols)), Some(Type::Matrix(sty!(_), rrows, rcols))) => { + ty!(_[lrows.max(rrows), lcols.max(rcols)]) + }, + (lhs, rhs) if lhs.is_subtype(&rhs) => rhs, + (lhs, rhs) if rhs.is_subtype(&lhs) => lhs, + (ty!(int), ty!(bool)) | (ty!(bool), ty!(int)) => ty!(felt), + (Some(Type::Vector(sty!(int), llen)), Some(Type::Vector(sty!(bool), rlen))) + | (Some(Type::Vector(sty!(bool), llen)), Some(Type::Vector(sty!(int), rlen))) => { + ty!(felt[core::cmp::max(llen, rlen)]) + }, + ( + Some(Type::Matrix(sty!(int), lrows, lcols)), + Some(Type::Matrix(sty!(bool), rrows, rcols)), + ) + | ( + Some(Type::Matrix(sty!(bool), lrows, lcols)), + Some(Type::Matrix(sty!(int), rrows, rcols)), + ) => { + ty!(felt[core::cmp::max(lrows, rrows), core::cmp::max(lcols, rcols)]) + }, + _ => None, + } + } } pub trait ScalarTypeMut: Typing { @@ -278,9 +310,9 @@ pub trait TypeMut: Typing + ScalarTypeMut { } #[derive(Clone, Debug, PartialEq, Eq)] -pub struct ShowOption(Option); +pub struct Show(T); -impl core::fmt::Display for ShowOption { +impl core::fmt::Display for Show> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match &self.0 { None => f.write_str("!"), @@ -289,7 +321,7 @@ impl core::fmt::Display for ShowOption { } } -impl core::fmt::Display for ShowOption { +impl core::fmt::Display for Show> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match &self.0 { None => f.write_str("?"), @@ -298,7 +330,7 @@ impl core::fmt::Display for ShowOption { } } -impl core::fmt::Display for ShowOption { +impl core::fmt::Display for Show> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match &self.0 { None => f.write_str("?"), @@ -307,7 +339,7 @@ impl core::fmt::Display for ShowOption { } } -impl core::fmt::Display for ShowOption { +impl core::fmt::Display for Show> { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { match &self.0 { None => f.write_str("_"), @@ -316,6 +348,40 @@ impl core::fmt::Display for ShowOption { } } +impl core::fmt::Display for Show> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "[{}]", + self.0.iter().map(|t| t.show_ty().to_string()).collect::>().join(", ") + ) + } +} + +impl Typing for Show { + fn kind(&self) -> Option { + self.0.kind() + } + fn ty(&self) -> Option { + self.0.ty() + } + fn scalar_ty(&self) -> Option { + self.0.scalar_ty() + } + fn show_kind(&self) -> Show> { + self.0.show_kind() + } + fn show_fn_ty(&self) -> Show> { + self.0.show_fn_ty() + } + fn show_ty(&self) -> Show> { + self.0.show_ty() + } + fn show_scalar_ty(&self) -> Show> { + self.0.show_scalar_ty() + } +} + impl Typing for ScalarType { fn ty(&self) -> Option { Some(Type::Scalar(Some(*self))) @@ -396,6 +462,7 @@ impl ScalarTypeMut for Kind { fn scalar_ty_mut(&mut self) -> &mut Option { match self { Kind::Value(ty) => ty.scalar_ty_mut(), + Kind::Aggregate(_) => panic!("Cannot mutate scalar type of an aggregate kind"), Kind::Callable(_) => panic!("Cannot mutate scalar type of a callable kind"), } } @@ -405,6 +472,7 @@ impl TypeMut for Kind { fn ty_mut(&mut self) -> &mut Option { match self { Kind::Value(ty) => ty, + Kind::Aggregate(_) => panic!("Cannot mutate type of an aggregate kind"), Kind::Callable(_) => panic!("Cannot mutate type of a callable kind"), } } @@ -415,10 +483,26 @@ impl Typing for Kind { Some(self.clone()) } fn ty(&self) -> Option { - let Kind::Value(ty) = self else { - return None; - }; - *ty + match self { + Kind::Value(ty) => *ty, + Kind::Aggregate(a) => { + let mut inner_ty = a.first().and_then(|t| t.ty()); + for item in a.iter().skip(1) { + let item_ty = item.ty(); + inner_ty = item_ty.lowest_common_supertype(&inner_ty); + } + match inner_ty { + None => None, + Some(Type::Scalar(st)) => ty!(st[a.len()]), + Some(Type::Vector(st, cols)) => ty!(st[a.len(), cols]), + Some(Type::Matrix(..)) => { + // An aggregate of matrices is not supported + None + }, + } + }, + Kind::Callable(_) => None, + } } } @@ -457,6 +541,15 @@ impl Typing for Option { } } +impl Typing for Box { + fn kind(&self) -> Option { + T::kind(self) + } + fn ty(&self) -> Option { + T::ty(self) + } +} + impl Typing for Span { fn kind(&self) -> Option { self.item.kind() @@ -468,17 +561,11 @@ impl Typing for Span { impl Typing for Vec { fn kind(&self) -> Option { - match self.first().map(|t| t.kind())?? { - Kind::Value(ty) => ty.map(|t| Kind::Value(Some(t))), - Kind::Callable(_) => unimplemented!("A vector of callables is not supported"), - } + let agg = self.iter().map(|t| t.kind().map(Box::new)).collect(); + Some(Kind::Aggregate(agg)) } fn ty(&self) -> Option { - match self.first().map(|t| t.ty())?? { - Type::Scalar(st) => ty!(st[self.len()]), - Type::Vector(st, cols) => ty!(st[self.len(), cols]), - Type::Matrix(..) => unimplemented!("A vector of matrices is not supported"), - } + self.kind().ty() } } @@ -731,4 +818,59 @@ mod tests { assert_subtype!(ty!(int[3, 4]); !ty!(bool[3, 4])); assert_subtype!(ty!(int[3, 4]); ty!(int[3, 4])); } + + macro_rules! assert_ty_eq { + ($a:expr, $b:expr) => {{ + eprintln!("{}: {} == {}", $a.show_kind(), $a.show_ty(), $b.show_ty()); + assert_eq!( + $a.ty(), + $b, + "Expected {} to be equal to {}, but it was not", + $a.ty().show_ty(), + $b.show_ty(), + ); + }}; + } + #[track_caller] + fn assert_tys_eq_with_rev(a: Vec, b: Option) { + assert_ty_eq!(a, b); + assert_ty_eq!(a.iter().rev().cloned().collect::>(), b); + } + #[test] + fn test_vec_typing() { + assert_ty_eq!(vec![ty!(felt), ty!(felt), ty!(felt)], ty!(felt[3])); + assert_tys_eq_with_rev(tys!([int, felt]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([bool, int]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([_, int]), ty!(_[2])); + assert_tys_eq_with_rev(tys!([?, int]), ty!(?)); + assert_tys_eq_with_rev(tys!([felt[5], felt[5]]), ty!(felt[2, 5])); + assert_tys_eq_with_rev(tys!([int[5], felt[5]]), ty!(felt[2, 5])); + assert_tys_eq_with_rev(tys!([bool[5], int[5]]), ty!(felt[2, 5])); + assert_tys_eq_with_rev(tys!([_[5], int[5]]), ty!(_[2, 5])); + assert_tys_eq_with_rev(tys!([bool[3], int[8]]), ty!(felt[2, 8])); + assert_tys_eq_with_rev(tys!([_[3], int[8]]), ty!(_[2, 8])); + assert_tys_eq_with_rev(tys!([?, int[5]]), ty!(?)); + assert_tys_eq_with_rev(tys!([int[5], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([bool[5, 2], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([int[5], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([bool[5, 2], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([int[5, 2], _]), ty!(?)); + assert_tys_eq_with_rev(tys!([int[5, 2]]), ty!(?)); + assert_tys_eq_with_rev(tys!([int, felt]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([bool, int]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([_, int]), ty!(_[2])); + assert_tys_eq_with_rev(tys!([?, int]), ty!(?)); + + assert_tys_eq_with_rev( + vec![tys!([int, felt]), tys!([int, felt]), tys!([int, felt])], + ty!(felt[3, 2]), + ); + assert_tys_eq_with_rev( + vec![tys!([bool, int]), tys!([bool, int]), tys!([bool, int])], + ty!(felt[3, 2]), + ); + assert_tys_eq_with_rev(vec![tys!([_, int]), tys!([_, int]), tys!([_, int])], ty!(_[3, 2])); + assert_tys_eq_with_rev(vec![tys!([?, int]), tys!([?, int]), tys!([?, int])], ty!(?)); + assert_tys_eq_with_rev(tys!([felt[5], int[5], bool[5]]), ty!(felt[3, 5])); + } } diff --git a/typing/src/types.rs b/typing/src/types.rs index 4563630e5..3b3cd7c93 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -110,9 +110,9 @@ macro_rules! ty { }; } -pub struct Push(Vec>); -impl Push { - pub fn push(mut self, ty: Option) -> Self { +pub struct Push(pub Vec); +impl Push { + pub fn push(mut self, ty: T) -> Self { self.0.push(ty); self } @@ -126,8 +126,8 @@ macro_rules! tys { (RES: $res:expr; ) => { $res }; - (RES: $res:expr; ?) => { - tys!(RES: $crate::Push::push($res, $crate::ty!(?));) + (RES: $res:expr; ? $(, $($rest:tt)+)?) => { + tys!(RES: $crate::Push::push($res, $crate::ty!(?)); $($($rest)+)?) }; (RES: $res:expr; _$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { tys!(RES: $crate::Push::push($res, $crate::ty!(_$([$($spec)+])?)); $($($rest)+)?) @@ -137,6 +137,25 @@ macro_rules! tys { }; } +#[macro_export] +macro_rules! kinds { + ([$($args:tt)+]) => { + kinds!(RES: Push(vec![]); $($args)+).0 + }; + (RES: $res:expr; ) => { + $res + }; + (RES: $res:expr; ?) => { + kinds!(RES: $crate::Push::push($res, $crate::kind!(?));) + }; + (RES: $res:expr; _$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { + kinds!(RES: $crate::Push::push($res, $crate::kind!(_$([$($spec)+])?)); $($($rest)+)?) + }; + (RES: $res:expr; $name:ident$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { + kinds!(RES: $crate::Push::push($res, $crate::kind!($name$([$($spec)+])?)); $($($rest)+)?) + }; +} + #[macro_export] macro_rules! tty { ([$($n1:ident$([$l1:expr])?),*]) => { @@ -652,6 +671,7 @@ impl BinType { #[derive(Hash, Debug, Clone, PartialEq, Eq)] pub enum Kind { Value(Option), + Aggregate(Vec>>), Callable(FunctionType), } @@ -659,6 +679,18 @@ impl core::fmt::Display for Kind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Value(ty) => write!(f, "{}", ty.show_ty()), + Self::Aggregate(tys) => { + write!( + f, + "[{}]", + tys.iter() + .map(|ty| ty + .as_ref() + .map_or("?".to_string(), |k| k.show_kind().to_string())) + .collect::>() + .join(", ") + ) + }, Self::Callable(fty) => write!(f, "{}", fty.show_fn_ty()), } } @@ -672,6 +704,9 @@ macro_rules! kind { (fn ($($args:tt)*) -> $($ret:tt)+) => { $crate::Kind::Callable($crate::fty!(fn ($($args)*) -> $($ret)+)) }; + ([$($spec:tt)+]) => { + $crate::Kind::Aggregate(kinds!([$($spec)+])) + }; ($($spec:tt)+) => { $crate::Kind::Value($crate::ty!($($spec)+)) }; From 41a3ff4e62f1265d3eb46a2b8d8acd33780c20db Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 31 Jul 2025 16:14:47 +0200 Subject: [PATCH 10/42] refactor(typing): rename int to uint --- typing/src/lib.rs | 263 ++++++++++++++++++++++---------------------- typing/src/types.rs | 116 +++++++++---------- 2 files changed, 192 insertions(+), 187 deletions(-) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index eea3aceb3..0971e56a4 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -65,7 +65,7 @@ pub trait Typing { matches!(self.scalar_ty(), sty!(bool)) } fn is_scalar_int(&self) -> bool { - matches!(self.scalar_ty(), sty!(int)) + matches!(self.scalar_ty(), sty!(uint)) } fn is_scalar(&self) -> bool { matches!(self.ty(), Some(Type::Scalar(_))) @@ -149,30 +149,30 @@ pub trait Typing { /// Felt type /// bool: ScalarType::Bool /// Boolean type - /// int: ScalarType::Int + /// uint: ScalarType::Int /// Integer type /// /// Subtyping rules: /// - felt > bool > _ - /// - felt > int > _ + /// - felt > uint > _ /// /// Which means: /// - `_` is a subtype of all scalar types /// - `bool` is a subtype of `felt`: a `bool` is a `felt with a `is_bool` property - /// - `int` is a subtype of `felt`: a `int` is a `felt` with the `constant` property + /// - `uint` is a subtype of `felt`: a `uint` is a `felt` with the `constant` property /// - /// self\\other || felt | bool | int | _ | - /// ============||======|======|=====|===| - /// felt || y | n | n | n | - /// bool || y | y | n | n | - /// int || y | n | y | n | - /// _ || y | y | y | y | + /// self\\other || felt | bool | uint | _ | + /// ============||======|======|======|===| + /// felt || y | n | n | n | + /// bool || y | y | n | n | + /// uint || y | n | y | n | + /// _ || y | y | y | y | fn is_scalar_subtype(&self, other: &impl Typing) -> bool { !matches!( (self.scalar_ty(), other.scalar_ty()), - (sty!(felt), sty!(bool) | sty!(int) | sty!(_)) - | (sty!(bool), sty!(int) | sty!(_)) - | (sty!(int), sty!(bool) | sty!(_)) + (sty!(felt), sty!(bool) | sty!(uint) | sty!(_)) + | (sty!(bool), sty!(uint) | sty!(_)) + | (sty!(uint), sty!(bool) | sty!(_)) ) } /// Returns true if `self` is a subtype of `other` @@ -185,7 +185,7 @@ pub trait Typing { /// Felt type /// bool: Type::Scalar(Some(ScalarType::Bool)) /// Boolean type - /// int: Type::Scalar(Some(ScalarType::Int)) + /// uint: Type::Scalar(Some(ScalarType::Int)) /// Integer type /// sty[len]: Type::Vector(Some(sty), len) /// Vector of length `len` with scalar type `sty` @@ -194,23 +194,23 @@ pub trait Typing { /// /// Subtyping rules: /// ? > _ > felt > bool - /// ... > felt > int + /// ... > felt > uint /// ? > _[l] > felt[l] > bool[l] - /// ... > felt[l] > int[l] + /// ... > felt[l] > uint[l] /// ? > _[r, c] > felt[r, c] > bool[r, c] - /// ... > felt[r, c] > int[r, c] + /// ... > felt[r, c] > uint[r, c] /// Assuming shapes are compatible, this function checks if the scalar types, /// with the added case of `?`, which all types are subtypes of. /// See [Typing::is_scalar_subtype] for a more detailed explanation /// of the subtyping rules of scalar types. /// - /// self\\other || felt | bool | int | _ | ? | - /// ============||======|======|=====|===|===| - /// felt ||[ y | n | n | n]| n | - /// bool ||[ y | y | n | n]| n | - /// int ||[ y | n | y | n]| n | - /// _ ||[ y | y | y | y]| n | - /// ? || y | y | y | y | y | + /// self\\other || felt | bool | uint | _ | ? | + /// ============||======|======|======|===|===| + /// felt ||[ y | n | n | n]| n | + /// bool ||[ y | y | n | n]| n | + /// uint ||[ y | n | y | n]| n | + /// _ ||[ y | y | y | y]| n | + /// ? || y | y | y | y | y | /// /// = self.is_scalar_subtype(other) | self == ? /// [...] Denotes the result of the [Typing::is_scalar_subtype] method. @@ -253,18 +253,18 @@ pub trait Typing { }, (lhs, rhs) if lhs.is_subtype(&rhs) => rhs, (lhs, rhs) if rhs.is_subtype(&lhs) => lhs, - (ty!(int), ty!(bool)) | (ty!(bool), ty!(int)) => ty!(felt), - (Some(Type::Vector(sty!(int), llen)), Some(Type::Vector(sty!(bool), rlen))) - | (Some(Type::Vector(sty!(bool), llen)), Some(Type::Vector(sty!(int), rlen))) => { + (ty!(uint), ty!(bool)) | (ty!(bool), ty!(uint)) => ty!(felt), + (Some(Type::Vector(sty!(uint), llen)), Some(Type::Vector(sty!(bool), rlen))) + | (Some(Type::Vector(sty!(bool), llen)), Some(Type::Vector(sty!(uint), rlen))) => { ty!(felt[core::cmp::max(llen, rlen)]) }, ( - Some(Type::Matrix(sty!(int), lrows, lcols)), + Some(Type::Matrix(sty!(uint), lrows, lcols)), Some(Type::Matrix(sty!(bool), rrows, rcols)), ) | ( Some(Type::Matrix(sty!(bool), lrows, lcols)), - Some(Type::Matrix(sty!(int), rrows, rcols)), + Some(Type::Matrix(sty!(uint), rrows, rcols)), ) => { ty!(felt[core::cmp::max(lrows, rrows), core::cmp::max(lcols, rcols)]) }, @@ -614,24 +614,24 @@ mod tests { assert_eq!(ty!(felt).scalar_ty(), sty!(felt)); assert_eq!(ty!(bool).ty(), Some(Type::Scalar(sty!(bool)))); assert_eq!(ty!(bool).scalar_ty(), sty!(bool)); - assert_eq!(ty!(int).ty(), Some(Type::Scalar(sty!(int)))); - assert_eq!(ty!(int).scalar_ty(), sty!(int)); + assert_eq!(ty!(uint).ty(), Some(Type::Scalar(sty!(uint)))); + assert_eq!(ty!(uint).scalar_ty(), sty!(uint)); assert_eq!(ty!(_[5]).ty(), Some(Type::Vector(sty!(_), 5))); assert_eq!(ty!(_[5]).scalar_ty(), sty!(_)); assert_eq!(ty!(felt[5]).ty(), Some(Type::Vector(sty!(felt), 5))); assert_eq!(ty!(felt[5]).scalar_ty(), sty!(felt)); assert_eq!(ty!(bool[5]).ty(), Some(Type::Vector(sty!(bool), 5))); assert_eq!(ty!(bool[5]).scalar_ty(), sty!(bool)); - assert_eq!(ty!(int[5]).ty(), Some(Type::Vector(sty!(int), 5))); - assert_eq!(ty!(int[5]).scalar_ty(), sty!(int)); + assert_eq!(ty!(uint[5]).ty(), Some(Type::Vector(sty!(uint), 5))); + assert_eq!(ty!(uint[5]).scalar_ty(), sty!(uint)); assert_eq!(ty!(_[3, 4]).ty(), Some(Type::Matrix(sty!(_), 3, 4))); assert_eq!(ty!(_[3, 4]).scalar_ty(), sty!(_)); assert_eq!(ty!(felt[3, 4]).ty(), Some(Type::Matrix(sty!(felt), 3, 4))); assert_eq!(ty!(felt[3, 4]).scalar_ty(), sty!(felt)); assert_eq!(ty!(bool[3, 4]).ty(), Some(Type::Matrix(sty!(bool), 3, 4))); assert_eq!(ty!(bool[3, 4]).scalar_ty(), sty!(bool)); - assert_eq!(ty!(int[3, 4]).ty(), Some(Type::Matrix(sty!(int), 3, 4))); - assert_eq!(ty!(int[3, 4]).scalar_ty(), sty!(int)); + assert_eq!(ty!(uint[3, 4]).ty(), Some(Type::Matrix(sty!(uint), 3, 4))); + assert_eq!(ty!(uint[3, 4]).scalar_ty(), sty!(uint)); } #[test] @@ -640,183 +640,183 @@ mod tests { assert_subtype!(ty!(?); ty!(_)); assert_subtype!(ty!(?); ty!(felt)); assert_subtype!(ty!(?); ty!(bool)); - assert_subtype!(ty!(?); ty!(int)); + assert_subtype!(ty!(?); ty!(uint)); assert_subtype!(ty!(?); ty!(_[5])); assert_subtype!(ty!(?); ty!(felt[5])); assert_subtype!(ty!(?); ty!(bool[5])); - assert_subtype!(ty!(?); ty!(int[5])); + assert_subtype!(ty!(?); ty!(uint[5])); assert_subtype!(ty!(?); ty!(_[3, 4])); assert_subtype!(ty!(?); ty!(felt[3, 4])); assert_subtype!(ty!(?); ty!(bool[3, 4])); - assert_subtype!(ty!(?); ty!(int[3, 4])); + assert_subtype!(ty!(?); ty!(uint[3, 4])); assert_subtype!(ty!(_); !ty!(?)); assert_subtype!(ty!(_); ty!(_)); assert_subtype!(ty!(_); ty!(felt)); assert_subtype!(ty!(_); ty!(bool)); - assert_subtype!(ty!(_); ty!(int)); + assert_subtype!(ty!(_); ty!(uint)); assert_subtype!(ty!(_); !ty!(_[5])); assert_subtype!(ty!(_); !ty!(felt[5])); assert_subtype!(ty!(_); !ty!(bool[5])); - assert_subtype!(ty!(_); !ty!(int[5])); + assert_subtype!(ty!(_); !ty!(uint[5])); assert_subtype!(ty!(_); !ty!(_[3, 4])); assert_subtype!(ty!(_); !ty!(felt[3, 4])); assert_subtype!(ty!(_); !ty!(bool[3, 4])); - assert_subtype!(ty!(_); !ty!(int[3, 4])); + assert_subtype!(ty!(_); !ty!(uint[3, 4])); assert_subtype!(ty!(felt); !ty!(?)); assert_subtype!(ty!(felt); !ty!(_)); assert_subtype!(ty!(felt); ty!(felt)); assert_subtype!(ty!(felt); !ty!(bool)); - assert_subtype!(ty!(felt); !ty!(int)); + assert_subtype!(ty!(felt); !ty!(uint)); assert_subtype!(ty!(felt); !ty!(_[5])); assert_subtype!(ty!(felt); !ty!(felt[5])); assert_subtype!(ty!(felt); !ty!(bool[5])); - assert_subtype!(ty!(felt); !ty!(int[5])); + assert_subtype!(ty!(felt); !ty!(uint[5])); assert_subtype!(ty!(felt); !ty!(_[3, 4])); assert_subtype!(ty!(felt); !ty!(felt[3, 4])); assert_subtype!(ty!(felt); !ty!(bool[3, 4])); - assert_subtype!(ty!(felt); !ty!(int[3, 4])); + assert_subtype!(ty!(felt); !ty!(uint[3, 4])); assert_subtype!(ty!(bool); !ty!(?)); assert_subtype!(ty!(bool); !ty!(_)); assert_subtype!(ty!(bool); ty!(felt)); assert_subtype!(ty!(bool); ty!(bool)); - assert_subtype!(ty!(bool); !ty!(int)); + assert_subtype!(ty!(bool); !ty!(uint)); assert_subtype!(ty!(bool); !ty!(_[5])); assert_subtype!(ty!(bool); !ty!(felt[5])); assert_subtype!(ty!(bool); !ty!(bool[5])); - assert_subtype!(ty!(bool); !ty!(int[5])); + assert_subtype!(ty!(bool); !ty!(uint[5])); assert_subtype!(ty!(bool); !ty!(_[3, 4])); assert_subtype!(ty!(bool); !ty!(felt[3, 4])); assert_subtype!(ty!(bool); !ty!(bool[3, 4])); - assert_subtype!(ty!(bool); !ty!(int[3, 4])); - - assert_subtype!(ty!(int); !ty!(?)); - assert_subtype!(ty!(int); !ty!(_)); - assert_subtype!(ty!(int); ty!(felt)); - assert_subtype!(ty!(int); !ty!(bool)); - assert_subtype!(ty!(int); ty!(int)); - assert_subtype!(ty!(int); !ty!(_[5])); - assert_subtype!(ty!(int); !ty!(felt[5])); - assert_subtype!(ty!(int); !ty!(bool[5])); - assert_subtype!(ty!(int); !ty!(int[5])); - assert_subtype!(ty!(int); !ty!(_[3, 4])); - assert_subtype!(ty!(int); !ty!(felt[3, 4])); - assert_subtype!(ty!(int); !ty!(bool[3, 4])); - assert_subtype!(ty!(int); !ty!(int[3, 4])); + assert_subtype!(ty!(bool); !ty!(uint[3, 4])); + + assert_subtype!(ty!(uint); !ty!(?)); + assert_subtype!(ty!(uint); !ty!(_)); + assert_subtype!(ty!(uint); ty!(felt)); + assert_subtype!(ty!(uint); !ty!(bool)); + assert_subtype!(ty!(uint); ty!(uint)); + assert_subtype!(ty!(uint); !ty!(_[5])); + assert_subtype!(ty!(uint); !ty!(felt[5])); + assert_subtype!(ty!(uint); !ty!(bool[5])); + assert_subtype!(ty!(uint); !ty!(uint[5])); + assert_subtype!(ty!(uint); !ty!(_[3, 4])); + assert_subtype!(ty!(uint); !ty!(felt[3, 4])); + assert_subtype!(ty!(uint); !ty!(bool[3, 4])); + assert_subtype!(ty!(uint); !ty!(uint[3, 4])); assert_subtype!(ty!(_[5]); !ty!(?)); assert_subtype!(ty!(_[5]); !ty!(_)); assert_subtype!(ty!(_[5]); !ty!(felt)); assert_subtype!(ty!(_[5]); !ty!(bool)); - assert_subtype!(ty!(_[5]); !ty!(int)); + assert_subtype!(ty!(_[5]); !ty!(uint)); assert_subtype!(ty!(_[5]); ty!(_[5])); assert_subtype!(ty!(_[5]); ty!(felt[5])); assert_subtype!(ty!(_[5]); ty!(bool[5])); - assert_subtype!(ty!(_[5]); ty!(int[5])); + assert_subtype!(ty!(_[5]); ty!(uint[5])); assert_subtype!(ty!(_[5]); !ty!(_[3, 4])); assert_subtype!(ty!(_[5]); !ty!(felt[3, 4])); assert_subtype!(ty!(_[5]); !ty!(bool[3, 4])); - assert_subtype!(ty!(_[5]); !ty!(int[3, 4])); + assert_subtype!(ty!(_[5]); !ty!(uint[3, 4])); assert_subtype!(ty!(felt[5]); !ty!(?)); assert_subtype!(ty!(felt[5]); !ty!(_)); assert_subtype!(ty!(felt[5]); !ty!(felt)); assert_subtype!(ty!(felt[5]); !ty!(bool)); - assert_subtype!(ty!(felt[5]); !ty!(int)); + assert_subtype!(ty!(felt[5]); !ty!(uint)); assert_subtype!(ty!(felt[5]); !ty!(_[5])); assert_subtype!(ty!(felt[5]); ty!(felt[5])); assert_subtype!(ty!(felt[5]); !ty!(bool[5])); - assert_subtype!(ty!(felt[5]); !ty!(int[5])); + assert_subtype!(ty!(felt[5]); !ty!(uint[5])); assert_subtype!(ty!(felt[5]); !ty!(_[3, 4])); assert_subtype!(ty!(felt[5]); !ty!(felt[3, 4])); assert_subtype!(ty!(felt[5]); !ty!(bool[3, 4])); - assert_subtype!(ty!(felt[5]); !ty!(int[3, 4])); + assert_subtype!(ty!(felt[5]); !ty!(uint[3, 4])); assert_subtype!(ty!(bool[5]); !ty!(?)); assert_subtype!(ty!(bool[5]); !ty!(_)); assert_subtype!(ty!(bool[5]); !ty!(felt)); assert_subtype!(ty!(bool[5]); !ty!(bool)); - assert_subtype!(ty!(bool[5]); !ty!(int)); + assert_subtype!(ty!(bool[5]); !ty!(uint)); assert_subtype!(ty!(bool[5]); !ty!(_[5])); assert_subtype!(ty!(bool[5]); ty!(felt[5])); assert_subtype!(ty!(bool[5]); ty!(bool[5])); - assert_subtype!(ty!(bool[5]); !ty!(int[5])); + assert_subtype!(ty!(bool[5]); !ty!(uint[5])); assert_subtype!(ty!(bool[5]); !ty!(_[3, 4])); assert_subtype!(ty!(bool[5]); !ty!(felt[3, 4])); assert_subtype!(ty!(bool[5]); !ty!(bool[3, 4])); - assert_subtype!(ty!(bool[5]); !ty!(int[3, 4])); - - assert_subtype!(ty!(int[5]); !ty!(?)); - assert_subtype!(ty!(int[5]); !ty!(_)); - assert_subtype!(ty!(int[5]); !ty!(felt)); - assert_subtype!(ty!(int[5]); !ty!(bool)); - assert_subtype!(ty!(int[5]); !ty!(int)); - assert_subtype!(ty!(int[5]); !ty!(_[5])); - assert_subtype!(ty!(int[5]); ty!(felt[5])); - assert_subtype!(ty!(int[5]); !ty!(bool[5])); - assert_subtype!(ty!(int[5]); ty!(int[5])); - assert_subtype!(ty!(int[5]); !ty!(_[3, 4])); - assert_subtype!(ty!(int[5]); !ty!(felt[3, 4])); - assert_subtype!(ty!(int[5]); !ty!(bool[3, 4])); - assert_subtype!(ty!(int[5]); !ty!(int[3, 4])); + assert_subtype!(ty!(bool[5]); !ty!(uint[3, 4])); + + assert_subtype!(ty!(uint[5]); !ty!(?)); + assert_subtype!(ty!(uint[5]); !ty!(_)); + assert_subtype!(ty!(uint[5]); !ty!(felt)); + assert_subtype!(ty!(uint[5]); !ty!(bool)); + assert_subtype!(ty!(uint[5]); !ty!(uint)); + assert_subtype!(ty!(uint[5]); !ty!(_[5])); + assert_subtype!(ty!(uint[5]); ty!(felt[5])); + assert_subtype!(ty!(uint[5]); !ty!(bool[5])); + assert_subtype!(ty!(uint[5]); ty!(uint[5])); + assert_subtype!(ty!(uint[5]); !ty!(_[3, 4])); + assert_subtype!(ty!(uint[5]); !ty!(felt[3, 4])); + assert_subtype!(ty!(uint[5]); !ty!(bool[3, 4])); + assert_subtype!(ty!(uint[5]); !ty!(uint[3, 4])); assert_subtype!(ty!(_[3, 4]); !ty!(?)); assert_subtype!(ty!(_[3, 4]); !ty!(_)); assert_subtype!(ty!(_[3, 4]); !ty!(felt)); assert_subtype!(ty!(_[3, 4]); !ty!(bool)); - assert_subtype!(ty!(_[3, 4]); !ty!(int)); + assert_subtype!(ty!(_[3, 4]); !ty!(uint)); assert_subtype!(ty!(_[3, 4]); !ty!(_[5])); assert_subtype!(ty!(_[3, 4]); !ty!(felt[5])); assert_subtype!(ty!(_[3, 4]); !ty!(bool[5])); - assert_subtype!(ty!(_[3, 4]); !ty!(int[5])); + assert_subtype!(ty!(_[3, 4]); !ty!(uint[5])); assert_subtype!(ty!(_[3, 4]); ty!(_[3, 4])); assert_subtype!(ty!(_[3, 4]); ty!(felt[3, 4])); assert_subtype!(ty!(_[3, 4]); ty!(bool[3, 4])); - assert_subtype!(ty!(_[3, 4]); ty!(int[3, 4])); + assert_subtype!(ty!(_[3, 4]); ty!(uint[3, 4])); assert_subtype!(ty!(felt[3, 4]); !ty!(?)); assert_subtype!(ty!(felt[3, 4]); !ty!(_)); assert_subtype!(ty!(felt[3, 4]); !ty!(felt)); assert_subtype!(ty!(felt[3, 4]); !ty!(bool)); - assert_subtype!(ty!(felt[3, 4]); !ty!(int)); + assert_subtype!(ty!(felt[3, 4]); !ty!(uint)); assert_subtype!(ty!(felt[3, 4]); !ty!(_[5])); assert_subtype!(ty!(felt[3, 4]); !ty!(felt[5])); assert_subtype!(ty!(felt[3, 4]); !ty!(bool[5])); - assert_subtype!(ty!(felt[3, 4]); !ty!(int[5])); + assert_subtype!(ty!(felt[3, 4]); !ty!(uint[5])); assert_subtype!(ty!(felt[3, 4]); !ty!(_[3, 4])); assert_subtype!(ty!(felt[3, 4]); ty!(felt[3, 4])); assert_subtype!(ty!(felt[3, 4]); !ty!(bool[3, 4])); - assert_subtype!(ty!(felt[3, 4]); !ty!(int[3, 4])); + assert_subtype!(ty!(felt[3, 4]); !ty!(uint[3, 4])); assert_subtype!(ty!(bool[3, 4]); !ty!(?)); assert_subtype!(ty!(bool[3, 4]); !ty!(_)); assert_subtype!(ty!(bool[3, 4]); !ty!(felt)); assert_subtype!(ty!(bool[3, 4]); !ty!(bool)); - assert_subtype!(ty!(bool[3, 4]); !ty!(int)); + assert_subtype!(ty!(bool[3, 4]); !ty!(uint)); assert_subtype!(ty!(bool[3, 4]); !ty!(_[5])); assert_subtype!(ty!(bool[3, 4]); !ty!(felt[5])); assert_subtype!(ty!(bool[3, 4]); !ty!(bool[5])); - assert_subtype!(ty!(bool[3, 4]); !ty!(int[5])); + assert_subtype!(ty!(bool[3, 4]); !ty!(uint[5])); assert_subtype!(ty!(bool[3, 4]); !ty!(_[3, 4])); assert_subtype!(ty!(bool[3, 4]); ty!(felt[3, 4])); assert_subtype!(ty!(bool[3, 4]); ty!(bool[3, 4])); - assert_subtype!(ty!(bool[3, 4]); !ty!(int[3, 4])); - - assert_subtype!(ty!(int[3, 4]); !ty!(?)); - assert_subtype!(ty!(int[3, 4]); !ty!(_)); - assert_subtype!(ty!(int[3, 4]); !ty!(felt)); - assert_subtype!(ty!(int[3, 4]); !ty!(bool)); - assert_subtype!(ty!(int[3, 4]); !ty!(int)); - assert_subtype!(ty!(int[3, 4]); !ty!(_[5])); - assert_subtype!(ty!(int[3, 4]); !ty!(felt[5])); - assert_subtype!(ty!(int[3, 4]); !ty!(bool[5])); - assert_subtype!(ty!(int[3, 4]); !ty!(int[5])); - assert_subtype!(ty!(int[3, 4]); !ty!(_[3, 4])); - assert_subtype!(ty!(int[3, 4]); ty!(felt[3, 4])); - assert_subtype!(ty!(int[3, 4]); !ty!(bool[3, 4])); - assert_subtype!(ty!(int[3, 4]); ty!(int[3, 4])); + assert_subtype!(ty!(bool[3, 4]); !ty!(uint[3, 4])); + + assert_subtype!(ty!(uint[3, 4]); !ty!(?)); + assert_subtype!(ty!(uint[3, 4]); !ty!(_)); + assert_subtype!(ty!(uint[3, 4]); !ty!(felt)); + assert_subtype!(ty!(uint[3, 4]); !ty!(bool)); + assert_subtype!(ty!(uint[3, 4]); !ty!(uint)); + assert_subtype!(ty!(uint[3, 4]); !ty!(_[5])); + assert_subtype!(ty!(uint[3, 4]); !ty!(felt[5])); + assert_subtype!(ty!(uint[3, 4]); !ty!(bool[5])); + assert_subtype!(ty!(uint[3, 4]); !ty!(uint[5])); + assert_subtype!(ty!(uint[3, 4]); !ty!(_[3, 4])); + assert_subtype!(ty!(uint[3, 4]); ty!(felt[3, 4])); + assert_subtype!(ty!(uint[3, 4]); !ty!(bool[3, 4])); + assert_subtype!(ty!(uint[3, 4]); ty!(uint[3, 4])); } macro_rules! assert_ty_eq { @@ -839,38 +839,41 @@ mod tests { #[test] fn test_vec_typing() { assert_ty_eq!(vec![ty!(felt), ty!(felt), ty!(felt)], ty!(felt[3])); - assert_tys_eq_with_rev(tys!([int, felt]), ty!(felt[2])); - assert_tys_eq_with_rev(tys!([bool, int]), ty!(felt[2])); - assert_tys_eq_with_rev(tys!([_, int]), ty!(_[2])); - assert_tys_eq_with_rev(tys!([?, int]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint, felt]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([bool, uint]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([_, uint]), ty!(_[2])); + assert_tys_eq_with_rev(tys!([?, uint]), ty!(?)); assert_tys_eq_with_rev(tys!([felt[5], felt[5]]), ty!(felt[2, 5])); - assert_tys_eq_with_rev(tys!([int[5], felt[5]]), ty!(felt[2, 5])); - assert_tys_eq_with_rev(tys!([bool[5], int[5]]), ty!(felt[2, 5])); - assert_tys_eq_with_rev(tys!([_[5], int[5]]), ty!(_[2, 5])); - assert_tys_eq_with_rev(tys!([bool[3], int[8]]), ty!(felt[2, 8])); - assert_tys_eq_with_rev(tys!([_[3], int[8]]), ty!(_[2, 8])); - assert_tys_eq_with_rev(tys!([?, int[5]]), ty!(?)); - assert_tys_eq_with_rev(tys!([int[5], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint[5], felt[5]]), ty!(felt[2, 5])); + assert_tys_eq_with_rev(tys!([bool[5], uint[5]]), ty!(felt[2, 5])); + assert_tys_eq_with_rev(tys!([_[5], uint[5]]), ty!(_[2, 5])); + assert_tys_eq_with_rev(tys!([bool[3], uint[8]]), ty!(felt[2, 8])); + assert_tys_eq_with_rev(tys!([_[3], uint[8]]), ty!(_[2, 8])); + assert_tys_eq_with_rev(tys!([?, uint[5]]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint[5], felt]), ty!(?)); assert_tys_eq_with_rev(tys!([bool[5, 2], felt]), ty!(?)); - assert_tys_eq_with_rev(tys!([int[5], felt]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint[5], felt]), ty!(?)); assert_tys_eq_with_rev(tys!([bool[5, 2], felt]), ty!(?)); - assert_tys_eq_with_rev(tys!([int[5, 2], _]), ty!(?)); - assert_tys_eq_with_rev(tys!([int[5, 2]]), ty!(?)); - assert_tys_eq_with_rev(tys!([int, felt]), ty!(felt[2])); - assert_tys_eq_with_rev(tys!([bool, int]), ty!(felt[2])); - assert_tys_eq_with_rev(tys!([_, int]), ty!(_[2])); - assert_tys_eq_with_rev(tys!([?, int]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint[5, 2], _]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint[5, 2]]), ty!(?)); + assert_tys_eq_with_rev(tys!([uint, felt]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([bool, uint]), ty!(felt[2])); + assert_tys_eq_with_rev(tys!([_, uint]), ty!(_[2])); + assert_tys_eq_with_rev(tys!([?, uint]), ty!(?)); assert_tys_eq_with_rev( - vec![tys!([int, felt]), tys!([int, felt]), tys!([int, felt])], + vec![tys!([uint, felt]), tys!([uint, felt]), tys!([uint, felt])], ty!(felt[3, 2]), ); assert_tys_eq_with_rev( - vec![tys!([bool, int]), tys!([bool, int]), tys!([bool, int])], + vec![tys!([bool, uint]), tys!([bool, uint]), tys!([bool, uint])], ty!(felt[3, 2]), ); - assert_tys_eq_with_rev(vec![tys!([_, int]), tys!([_, int]), tys!([_, int])], ty!(_[3, 2])); - assert_tys_eq_with_rev(vec![tys!([?, int]), tys!([?, int]), tys!([?, int])], ty!(?)); - assert_tys_eq_with_rev(tys!([felt[5], int[5], bool[5]]), ty!(felt[3, 5])); + assert_tys_eq_with_rev( + vec![tys!([_, uint]), tys!([_, uint]), tys!([_, uint])], + ty!(_[3, 2]), + ); + assert_tys_eq_with_rev(vec![tys!([?, uint]), tys!([?, uint]), tys!([?, uint])], ty!(?)); + assert_tys_eq_with_rev(tys!([felt[5], uint[5], bool[5]]), ty!(felt[3, 5])); } } diff --git a/typing/src/types.rs b/typing/src/types.rs index 3b3cd7c93..aa1a04f96 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -12,7 +12,7 @@ impl core::fmt::Display for ScalarType { match self { Self::Felt => f.write_str("felt"), Self::Bool => f.write_str("bool"), - Self::Int => f.write_str("int"), + Self::Int => f.write_str("uint"), } } } @@ -38,7 +38,7 @@ macro_rules! sty { (bool) => { Some($crate::ScalarType::Bool) }; - (int) => { + (uint) => { Some($crate::ScalarType::Int) }; ($sty:ident) => { @@ -468,11 +468,11 @@ impl BinType { /// /// Assuming shapes are compatible, the following table shows the result type /// based on the scalar types of the operands: - /// ? == ? || felt | bool | int | _ | ? + /// ? == ? || felt | bool | uint | _ | ? /// =========||======|======|======|======|===== /// felt || bool | bool | bool | bool | ? /// bool || bool | bool | bool | bool | ? - /// int || bool | bool | bool | bool | ? + /// uint || bool | bool | bool | bool | ? /// _ || bool | bool | bool | bool | ? /// ? || ? | ? | ? | ? | ? /// @@ -502,21 +502,21 @@ impl BinType { /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleShapes]. /// /// based on the scalar types of the operands: - /// ? + ? || felt | bool | int | _ | ? + /// ? + ? || felt | bool | uint | _ | ? /// =========||======|======|======|======|===== /// felt || felt | felt | felt | felt | felt /// bool || felt | felt | felt | felt | felt - /// int || felt | felt | int | _ | ? + /// uint || felt | felt | uint | _ | ? /// _ || felt | felt | _ | _ | ? /// ? || felt | felt | ? | ? | ? /// /// So, the result type of an addition is: /// - an error if lhs or rhs is not a scalar type or `?`, /// - symmetric over the operands, - /// - felt + any -> felt - /// - bool + any -> felt - /// - ? + any -> ? - /// - int + int -> int + /// - felt + any -> felt + /// - bool + any -> felt + /// - ? + any -> ? + /// - uint + uint -> uint /// - everything else is an unknown scalar type `_` pub fn infer_bin_ty_add(&self) -> Result, TypeError> { if let Some(ret) = self.result() { @@ -531,7 +531,7 @@ impl BinType { bty!(felt + any) | bty!(any + felt) => Ok(ty!(felt)), bty!(bool + any) | bty!(any + bool) => Ok(ty!(felt)), bty!(? + any) | bty!(any + ?) => Ok(ty!(?)), - bty!(int + int) => Ok(ty!(int)), + bty!(uint + uint) => Ok(ty!(uint)), _ => Ok(ty!(_)), } } @@ -541,33 +541,33 @@ impl BinType { /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleShapes]. /// /// based on the scalar types of the operands: - /// ? - ? || felt | bool | int | _ | ? + /// ? - ? || felt | bool | uint | _ | ? /// =========||======|======|======|======|===== /// felt || felt | felt | felt | felt | felt /// bool || felt | felt | felt | felt | felt - /// int || felt | felt | int | _ | ? + /// uint || felt | felt | uint | _ | ? /// _ || felt | felt | _ | _ | ? /// ? || felt | felt | ? | ? | ? /// /// So, the result type of a substraction is: /// - an error if either lhs or rhs is not a scalar type or `?`, /// - symmetric over the operands, - /// - felt - any -> felt - /// - bool - any -> felt - /// - int - int -> int - /// - ? - any -> ? + /// - felt - any -> felt + /// - bool - any -> felt + /// - uint - uint -> uint + /// - ? - any -> ? /// - everything else is an unknown scalar type `_` /// /// This is the same as [BinType::infer_bin_ty_add], so it reuses that method. /// /// NOTE: if we refine the types as described in #432, this method will need to be - /// updated to handle the substraction of `bool` and `int` types correctly. + /// updated to handle the substraction of `bool` and `uint` types correctly. /// This will no longer be symmetric over the operands! /// Because: /// - 0 - bool = - bool -> felt /// - bool - 0 -> bool - /// - 0 - int = - int -> int (or error depending on the design) - /// - int - 0 -> int + /// - 0 - uint = - uint -> uint (or error depending on the design) + /// - uint - 0 -> uint /// - 1 - bool -> bool /// - bool - 1 -> felt pub fn infer_bin_ty_sub(&self) -> Result, TypeError> { @@ -579,22 +579,22 @@ impl BinType { /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleShapes]. /// /// based on the scalar types of the operands: - /// ? * ? || felt | bool | int | _ | ? + /// ? * ? || felt | bool | uint | _ | ? /// =========||======|======|======|======|===== /// felt || felt | felt | felt | felt | felt - /// bool || felt | bool | int | _ | ? - /// int || felt | int | int | _ | ? + /// bool || felt | bool | uint | _ | ? + /// uint || felt | uint | uint | _ | ? /// _ || felt | _ | _ | _ | ? /// ? || felt | ? | ? | ? | ? /// /// So, the result type of a multiplication is: /// - an error if either lhs or rhs is not a scalar type or `?`, /// - symmetric over the operands, - /// - felt * any -> felt - /// - ? * any -> ? - /// - _ * any -> _ - /// - int * int -> int - /// - bool * x -> x + /// - felt * any -> felt + /// - ? * any -> ? + /// - _ * any -> _ + /// - uint * uint -> uint + /// - bool * x -> x /// - everything else is an unknown scalar type `_` pub fn infer_bin_ty_mul(&self) -> Result, TypeError> { if let Some(ret) = self.result() { @@ -609,7 +609,7 @@ impl BinType { bty!(felt * any) | bty!(any * felt) => Ok(ty!(felt)), bty!(? * any) | bty!(any * ?) => Ok(ty!(?)), bty!(_ * any) | bty!(any * _) => Ok(ty!(_)), - bty!(int * int) => Ok(ty!(int)), + bty!(uint * uint) => Ok(ty!(uint)), bty!(bool * any:x) | bty!(any:x * bool) => Ok(*x), _ => Ok(ty!(_)), } @@ -620,31 +620,30 @@ impl BinType { /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleBinOp]. /// /// based on the scalar types of the operands: - /// ? ^ ? || felt | bool | int | _ | ? + /// ? ^ ? || felt | bool | uint | _ | ? /// =========||======|======|======|======|===== /// felt || err | err | felt | _ | ? /// bool || err | err | bool | _ | ? - /// int || err | err | int | _ | ? + /// uint || err | err | uint | _ | ? /// _ || err | err | _ | _ | ? /// ? || err | err | ? | ? | ? /// - /// /// So, the result type of an exponentiation is: /// - an error if either lhs or rhs is not a scalar type or `?`, - /// - an error if the rhs is not an int or `?`, - /// - any ^ ? -> ?, - /// - ? ^ any -> ?, - /// - any ^ _ -> _, - /// - any:x ^ int -> lhs, + /// - an error if the rhs is not an uint or `?`, + /// - any ^ ? -> ?, + /// - ? ^ any -> ?, + /// - any ^ _ -> _, + /// - any:x ^ uint -> lhs, /// /// Because: /// - it is an error if either lhs or rhs is not a scalar type or `?`, - /// - it is an error if rhs is not an int or `?`, + /// - it is an error if rhs is not an uint or `?`, /// - a bool to any power is still a bool: /// - 0^n = 0 /// - 1^n = 1 /// - a felt to any power is still a felt - /// - an int to any power is still an int + /// - an uint to any power is still an uint /// - a _ to any power is still a _ /// - a ? to any power is still a ? pub fn infer_bin_ty_exp(&self) -> Result, TypeError> { @@ -662,7 +661,7 @@ impl BinType { }, bty!(any ^ ?) | bty!(? ^ any) => Ok(ty!(?)), bty!(any ^ _) => Ok(ty!(_)), - bty!(any:lhs ^ int) => Ok(*lhs), + bty!(any:lhs ^ uint) => Ok(*lhs), _ => unreachable!("Undefined case for infer_bin_ty_exp: {self}"), } } @@ -721,7 +720,7 @@ mod tests { assert_eq!(sty!(_), None::); assert_eq!(sty!(felt), Some(ScalarType::Felt)); assert_eq!(sty!(bool), Some(ScalarType::Bool)); - assert_eq!(sty!(int), Some(ScalarType::Int)); + assert_eq!(sty!(uint), Some(ScalarType::Int)); } #[test] @@ -730,9 +729,9 @@ mod tests { assert_eq!(ty!(_), Some(Type::Scalar(None))); assert_eq!(ty!(felt), Some(Type::Scalar(Some(ScalarType::Felt)))); assert_eq!(ty!(bool), Some(Type::Scalar(Some(ScalarType::Bool)))); - assert_eq!(ty!(int), Some(Type::Scalar(Some(ScalarType::Int)))); + assert_eq!(ty!(uint), Some(Type::Scalar(Some(ScalarType::Int)))); assert_eq!(ty!(_[5]), Some(Type::Vector(None, 5))); - assert_eq!(ty!(int[5]), Some(Type::Vector(Some(ScalarType::Int), 5))); + assert_eq!(ty!(uint[5]), Some(Type::Vector(Some(ScalarType::Int), 5))); assert_eq!(ty!(_[3, 4]), Some(Type::Matrix(None, 3, 4))); assert_eq!(ty!(felt[3, 4]), Some(Type::Matrix(Some(ScalarType::Felt), 3, 4))); } @@ -762,30 +761,33 @@ mod tests { FunctionType::Evaluator(vec![ty!(felt[1]), ty!(felt[3])]) ); - assert_eq!(fty!(fn(int) -> felt), FunctionType::Function(vec![ty!(int)], ty!(felt))); + assert_eq!(fty!(fn(uint) -> felt), FunctionType::Function(vec![ty!(uint)], ty!(felt))); assert_eq!( - fty!(fn(int[5]) -> felt[3, 4]), - FunctionType::Function(vec![ty!(int[5])], ty!(felt[3, 4]),) + fty!(fn(uint[5]) -> felt[3, 4]), + FunctionType::Function(vec![ty!(uint[5])], ty!(felt[3, 4]),) ); assert_eq!( - fty!(fn(int[5], felt) -> felt[3, 4]), - FunctionType::Function(vec![ty!(int[5]), ty!(felt)], ty!(felt[3, 4]),) + fty!(fn(uint[5], felt) -> felt[3, 4]), + FunctionType::Function(vec![ty!(uint[5]), ty!(felt)], ty!(felt[3, 4]),) ); assert_eq!( - fty!(fn(int[5], felt, bool[3, 4]) -> felt[3, 4]), - FunctionType::Function(vec![ty!(int[5]), ty!(felt), ty!(bool[3, 4]),], ty!(felt[3, 4]),) + fty!(fn(uint[5], felt, bool[3, 4]) -> felt[3, 4]), + FunctionType::Function( + vec![ty!(uint[5]), ty!(felt), ty!(bool[3, 4]),], + ty!(felt[3, 4]), + ) ); } #[test] fn test_macro_bin_type() { - assert_eq!(bty!(int + felt), BinType::Add(ty!(int), ty!(felt), ty!(?))); + assert_eq!(bty!(uint + felt), BinType::Add(ty!(uint), ty!(felt), ty!(?))); assert_eq!(bty!(_ - felt), BinType::Sub(ty!(_), ty!(felt), ty!(?))); assert_eq!(bty!(? = felt), BinType::Eq(ty!(?), ty!(felt), ty!(?))); - assert_eq!(bty!(int + ?), BinType::Add(ty!(int), ty!(?), ty!(?))); - assert_eq!(bty!(int - felt), BinType::Sub(ty!(int), ty!(felt), ty!(?))); - assert_eq!(bty!(int[2] * felt[2]), BinType::Mul(ty!(int[2]), ty!(felt[2]), ty!(?))); - assert_eq!(bty!(int[2, 3] ^ _), BinType::Exp(ty!(int[2, 3]), ty!(_), ty!(?))); + assert_eq!(bty!(uint + ?), BinType::Add(ty!(uint), ty!(?), ty!(?))); + assert_eq!(bty!(uint - felt), BinType::Sub(ty!(uint), ty!(felt), ty!(?))); + assert_eq!(bty!(uint[2] * felt[2]), BinType::Mul(ty!(uint[2]), ty!(felt[2]), ty!(?))); + assert_eq!(bty!(uint[2, 3] ^ _), BinType::Exp(ty!(uint[2, 3]), ty!(_), ty!(?))); assert_eq!(bty!(bool[5] = _[5]), BinType::Eq(ty!(bool[5]), ty!(_[5]), ty!(?))); } @@ -793,8 +795,8 @@ mod tests { fn test_macro_kind() { assert_eq!(kind!(ev([])), Kind::Callable(fty!(ev([])))); assert_eq!(kind!(ev([a])), Kind::Callable(fty!(ev([a])))); - assert_eq!(kind!(fn(int) -> felt), Kind::Callable(fty!(fn(int) -> felt))); - assert_eq!(kind!(int), Kind::Value(ty!(int))); + assert_eq!(kind!(fn(uint) -> felt), Kind::Callable(fty!(fn(uint) -> felt))); + assert_eq!(kind!(uint), Kind::Value(ty!(uint))); assert_eq!(kind!(_), Kind::Value(ty!(_))); assert_eq!(kind!(bool[3, 4]), Kind::Value(ty!(bool[3, 4]))); } From 4891f5adca33ba125e2f6af21a96714332c09b54 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 31 Jul 2025 16:33:47 +0200 Subject: [PATCH 11/42] feat(typing): add an optional span to the TypeError enum --- typing/src/lib.rs | 12 +++++++++--- typing/src/types.rs | 10 +++++----- 2 files changed, 14 insertions(+), 8 deletions(-) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index 0971e56a4..51bb6099f 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -2,32 +2,38 @@ mod types; use std::fmt::Debug; -use miden_diagnostics::Span; +use miden_diagnostics::{SourceSpan, Span}; pub use types::*; pub enum TypeError { IncompatibleScalarTypes { lhs: Option, rhs: Option, + span: Option, }, IncompatibleShapes { lhs: Option, rhs: Option, + span: Option, }, IncompatibleType { lhs: Option, rhs: Option, + span: Option, }, TypeAlreadySet { lhs: Option, rhs: Option, + span: Option, }, NotASubtype { lhs: Option, rhs: Option, + span: Option, }, IncompatibleBinOp { bin_ty: BinType, + span: Option, }, } @@ -285,7 +291,7 @@ pub trait ScalarTypeMut: Typing { // Allow widening of types *self.scalar_ty_mut() = new_ty; } else { - return Err(TypeError::IncompatibleScalarTypes { lhs: ty, rhs: new_ty }); + return Err(TypeError::IncompatibleScalarTypes { lhs: ty, rhs: new_ty, span: None }); } Ok(()) } @@ -303,7 +309,7 @@ pub trait TypeMut: Typing + ScalarTypeMut { // Allow widening of types *self.ty_mut() = new_ty; } else { - return Err(TypeError::NotASubtype { lhs: ty, rhs: new_ty }); + return Err(TypeError::NotASubtype { lhs: ty, rhs: new_ty, span: None }); } Ok(()) } diff --git a/typing/src/types.rs b/typing/src/types.rs index aa1a04f96..faa41262b 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -493,7 +493,7 @@ impl BinType { if self.lhs().is_shape_compatible(&self.rhs()) { Ok(ty!(bool)) } else { - Err(TypeError::IncompatibleBinOp { bin_ty: *self }) + Err(TypeError::IncompatibleBinOp { bin_ty: *self, span: None }) } } @@ -525,7 +525,7 @@ impl BinType { let lhs = self.lhs(); let rhs = self.rhs(); if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { - return Err(TypeError::IncompatibleShapes { lhs, rhs }); + return Err(TypeError::IncompatibleShapes { lhs, rhs, span: None }); } match self { bty!(felt + any) | bty!(any + felt) => Ok(ty!(felt)), @@ -603,7 +603,7 @@ impl BinType { let lhs = self.lhs(); let rhs = self.rhs(); if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { - return Err(TypeError::IncompatibleShapes { lhs, rhs }); + return Err(TypeError::IncompatibleShapes { lhs, rhs, span: None }); } match self { bty!(felt * any) | bty!(any * felt) => Ok(ty!(felt)), @@ -653,11 +653,11 @@ impl BinType { let lhs = self.lhs(); let rhs = self.rhs(); if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { - return Err(TypeError::IncompatibleBinOp { bin_ty: *self }); + return Err(TypeError::IncompatibleBinOp { bin_ty: *self, span: None }); } match self { bty!(any ^ felt) | bty!(any ^ bool) => { - Err(TypeError::IncompatibleBinOp { bin_ty: *self }) + Err(TypeError::IncompatibleBinOp { bin_ty: *self, span: None }) }, bty!(any ^ ?) | bty!(? ^ any) => Ok(ty!(?)), bty!(any ^ _) => Ok(ty!(_)), From 3ad4b903b5784e6d528bf2e47626e6993d92f848 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 31 Jul 2025 17:50:20 +0200 Subject: [PATCH 12/42] fix(typing): properly rename Int to UInt --- typing/src/lib.rs | 4 ++-- typing/src/types.rs | 12 ++++++------ 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index 51bb6099f..f005ec021 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -155,7 +155,7 @@ pub trait Typing { /// Felt type /// bool: ScalarType::Bool /// Boolean type - /// uint: ScalarType::Int + /// uint: ScalarType::UInt /// Integer type /// /// Subtyping rules: @@ -191,7 +191,7 @@ pub trait Typing { /// Felt type /// bool: Type::Scalar(Some(ScalarType::Bool)) /// Boolean type - /// uint: Type::Scalar(Some(ScalarType::Int)) + /// uint: Type::Scalar(Some(ScalarType::UInt)) /// Integer type /// sty[len]: Type::Vector(Some(sty), len) /// Vector of length `len` with scalar type `sty` diff --git a/typing/src/types.rs b/typing/src/types.rs index faa41262b..0da36a1a9 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -4,7 +4,7 @@ use crate::{TypeError, Typing}; pub enum ScalarType { Felt, Bool, - Int, + UInt, } impl core::fmt::Display for ScalarType { @@ -12,7 +12,7 @@ impl core::fmt::Display for ScalarType { match self { Self::Felt => f.write_str("felt"), Self::Bool => f.write_str("bool"), - Self::Int => f.write_str("uint"), + Self::UInt => f.write_str("uint"), } } } @@ -39,7 +39,7 @@ macro_rules! sty { Some($crate::ScalarType::Bool) }; (uint) => { - Some($crate::ScalarType::Int) + Some($crate::ScalarType::UInt) }; ($sty:ident) => { $sty @@ -720,7 +720,7 @@ mod tests { assert_eq!(sty!(_), None::); assert_eq!(sty!(felt), Some(ScalarType::Felt)); assert_eq!(sty!(bool), Some(ScalarType::Bool)); - assert_eq!(sty!(uint), Some(ScalarType::Int)); + assert_eq!(sty!(uint), Some(ScalarType::UInt)); } #[test] @@ -729,9 +729,9 @@ mod tests { assert_eq!(ty!(_), Some(Type::Scalar(None))); assert_eq!(ty!(felt), Some(Type::Scalar(Some(ScalarType::Felt)))); assert_eq!(ty!(bool), Some(Type::Scalar(Some(ScalarType::Bool)))); - assert_eq!(ty!(uint), Some(Type::Scalar(Some(ScalarType::Int)))); + assert_eq!(ty!(uint), Some(Type::Scalar(Some(ScalarType::UInt)))); assert_eq!(ty!(_[5]), Some(Type::Vector(None, 5))); - assert_eq!(ty!(uint[5]), Some(Type::Vector(Some(ScalarType::Int), 5))); + assert_eq!(ty!(uint[5]), Some(Type::Vector(Some(ScalarType::UInt), 5))); assert_eq!(ty!(_[3, 4]), Some(Type::Matrix(None, 3, 4))); assert_eq!(ty!(felt[3, 4]), Some(Type::Matrix(Some(ScalarType::Felt), 3, 4))); } From 3402d8f0cfa8127ed35d48769a8427092d281958 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Fri, 1 Aug 2025 12:27:50 +0200 Subject: [PATCH 13/42] feat(typing): impl Display for TypeError --- typing/src/lib.rs | 32 ++++++++++++++++++++++++++++++++ 1 file changed, 32 insertions(+) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index f005ec021..635a667f4 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -5,6 +5,7 @@ use std::fmt::Debug; use miden_diagnostics::{SourceSpan, Span}; pub use types::*; +#[derive(Clone, Debug, PartialEq, Eq)] pub enum TypeError { IncompatibleScalarTypes { lhs: Option, @@ -37,6 +38,37 @@ pub enum TypeError { }, } +impl core::fmt::Display for TypeError { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + TypeError::IncompatibleScalarTypes { lhs, rhs, .. } => { + write!(f, "incompatible scalar types: {} and {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::IncompatibleShapes { lhs, rhs, .. } => { + write!(f, "incompatible shapes: {} and {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::IncompatibleType { lhs, rhs, .. } => { + write!(f, "incompatible types: {} and {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::TypeAlreadySet { lhs, rhs, .. } => { + write!(f, "type already set: {} vs {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::NotASubtype { lhs, rhs, .. } => { + write!(f, "type {} is not a subtype of {}", Show(*lhs), Show(*rhs))?; + Ok(()) + }, + TypeError::IncompatibleBinOp { bin_ty, .. } => { + write!(f, "incompatible types for binary operation: {}", bin_ty.show_fn_ty())?; + Ok(()) + }, + } + } +} + pub trait Typing { fn kind(&self) -> Option { Some(Kind::Value(self.ty())) From 60fb953f21c209ae87020f5e1ed946c2f4666dce Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Fri, 1 Aug 2025 13:48:41 +0200 Subject: [PATCH 14/42] refactor(typing): integrate into codebase --- mir/Cargo.toml | 3 +- mir/src/ir/mod.rs | 2 + mir/src/ir/nodes/ops/mod.rs | 3 +- mir/src/ir/nodes/ops/parameter.rs | 8 +- mir/src/ir/nodes/ops/value.rs | 19 +- mir/src/passes/inlining.rs | 16 +- mir/src/passes/mod.rs | 4 +- mir/src/passes/translate.rs | 53 +++-- mir/src/passes/unrolling.rs | 83 ++++---- parser/Cargo.toml | 1 + parser/src/ast/declarations.rs | 60 ++++-- parser/src/ast/expression.rs | 189 +++++++++++------- parser/src/ast/trace.rs | 60 ++++-- parser/src/ast/types.rs | 87 ++------ parser/src/lexer/mod.rs | 11 +- parser/src/parser/grammar.lalrpop | 14 +- .../src/parser/tests/constant_propagation.rs | 12 +- parser/src/parser/tests/functions.rs | 40 ++-- parser/src/parser/tests/mod.rs | 11 +- parser/src/parser/tests/modules.rs | 24 ++- parser/src/sema/binding_type.rs | 31 ++- parser/src/sema/semantic_analysis.rs | 145 ++++++++------ parser/src/transforms/constant_propagation.rs | 10 +- 23 files changed, 479 insertions(+), 407 deletions(-) diff --git a/mir/Cargo.toml b/mir/Cargo.toml index f1bce99c5..faf56c817 100644 --- a/mir/Cargo.toml +++ b/mir/Cargo.toml @@ -14,6 +14,7 @@ edition.workspace = true [dependencies] air-parser = { package = "air-parser", path = "../parser", version = "0.5" } air-pass = { package = "air-pass", path = "../pass", version = "0.5" } +typing = { package = "typing", path = "../typing", version = "0.1" } anyhow = { workspace = true } derive-ir = { package = "air-derive-ir", path = "./derive-ir", version = "0.5" } miden-core = { package = "miden-core", version = "0.13", default-features = false } @@ -21,4 +22,4 @@ miden-diagnostics = { workspace = true } pretty_assertions = "1.4" rand = "0.9" thiserror = { workspace = true } -winter-math = { package = "winter-math", version = "0.12", default-features = false } \ No newline at end of file +winter-math = { package = "winter-math", version = "0.12", default-features = false } diff --git a/mir/src/ir/mod.rs b/mir/src/ir/mod.rs index 9ae93c570..86b7d8630 100644 --- a/mir/src/ir/mod.rs +++ b/mir/src/ir/mod.rs @@ -8,6 +8,7 @@ mod owner; mod quad_eval; mod utils; pub extern crate derive_ir; +pub extern crate typing; pub use bus::Bus; pub use derive_ir::Builder; @@ -18,6 +19,7 @@ pub use node::Node; pub use nodes::*; pub use owner::Owner; pub use quad_eval::{QuadFelt, RandomInputs}; +pub use typing::*; pub use utils::*; /// A trait for nodes that can have children /// This is used with the Child trait to allow for easy traversal and manipulation of the graph diff --git a/mir/src/ir/nodes/ops/mod.rs b/mir/src/ir/nodes/ops/mod.rs index 64d1674a3..22a472075 100644 --- a/mir/src/ir/nodes/ops/mod.rs +++ b/mir/src/ir/nodes/ops/mod.rs @@ -29,8 +29,9 @@ pub use matrix::Matrix; pub use mul::Mul; pub use parameter::Parameter; pub use sub::Sub; +pub use typing::*; pub use value::{ - BusAccess, ConstantValue, MirType, MirValue, PeriodicColumnAccess, PublicInputAccess, + BusAccess, ConstantValue, MirValue, PeriodicColumnAccess, PublicInputAccess, PublicInputTableAccess, SpannedMirValue, TraceAccess, TraceAccessBinding, Value, }; pub use vector::Vector; diff --git a/mir/src/ir/nodes/ops/parameter.rs b/mir/src/ir/nodes/ops/parameter.rs index 939215d27..4a1bdfe44 100644 --- a/mir/src/ir/nodes/ops/parameter.rs +++ b/mir/src/ir/nodes/ops/parameter.rs @@ -2,8 +2,8 @@ use std::hash::{Hash, Hasher}; use miden_diagnostics::{SourceSpan, Spanned}; -use super::MirType; use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Singleton}; +use typing::*; /// A MIR operation to represent a `Parameter` in a function or evaluator. /// Also used in If and For loops to represent declared parameters. @@ -16,19 +16,19 @@ pub struct Parameter { /// The position of the `Parameter` in the referred node's `Parameter` list pub position: usize, /// The type of the `Parameter` - pub ty: MirType, + pub ty: Option, pub _node: Singleton, #[span] pub span: SourceSpan, } impl Parameter { - pub fn create(position: usize, ty: MirType, span: SourceSpan) -> Link { + pub fn create(position: usize, ty: Type, span: SourceSpan) -> Link { Op::Parameter(Self { parents: Vec::default(), ref_node: BackLink::none(), position, - ty, + ty: Some(ty), _node: Singleton::none(), span, }) diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index f80c7d78e..868f8b621 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -4,6 +4,7 @@ use air_parser::ast::{ use miden_diagnostics::{SourceSpan, Spanned}; use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Singleton}; +use typing::*; /// A MIR operation to represent a known value, [Value]. /// @@ -149,24 +150,6 @@ pub struct SpannedMirValue { pub value: MirValue, } -#[derive(Debug, Default, Eq, PartialEq, Clone, Hash)] -pub enum MirType { - #[default] - Felt, - Vector(usize), - Matrix(usize, usize), -} - -impl From for MirType { - fn from(value: ast::Type) -> Self { - match value { - ast::Type::Felt => MirType::Felt, - ast::Type::Vector(n) => MirType::Vector(n), - ast::Type::Matrix(cols, rows) => MirType::Matrix(cols, rows), - } - } -} - /// Represents an access of a PeriodicColumn, similar in nature to [TraceAccess]. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash)] pub struct PeriodicColumnAccess { diff --git a/mir/src/passes/inlining.rs b/mir/src/passes/inlining.rs index 08115cdac..b7ae9ae65 100644 --- a/mir/src/passes/inlining.rs +++ b/mir/src/passes/inlining.rs @@ -7,8 +7,8 @@ use super::{duplicate_node_or_replace, visitor::Visitor}; use crate::{ CompileError, ir::{ - Accessor, Graph, Link, Mir, MirType, MirValue, Node, Op, Parameter, Parent, Root, - SpannedMirValue, TraceAccessBinding, Value, Vector, + Accessor, Graph, Link, Mir, MirValue, Node, Op, Parameter, Parent, Root, SpannedMirValue, + TraceAccessBinding, Type, Value, Vector, }, }; @@ -515,8 +515,8 @@ fn check_evaluator_argument_sizes( } else if let Some(parameter) = child.as_parameter() { let Parameter { ty, .. } = parameter.deref(); let size = match ty { - MirType::Felt => 1, - MirType::Vector(len) => *len, + Some(Type::Scalar(_)) => 1, + Some(Type::Vector(_, len)) => *len, _ => unreachable!("expected felt or vector, got {:?}", ty), }; trace_segments_arg_vector_len += size; @@ -535,8 +535,8 @@ fn check_evaluator_argument_sizes( } else if let Some(parameter) = indexable.as_parameter() { let Parameter { ty, .. } = parameter.deref(); let size = match ty { - MirType::Felt => 1, - MirType::Vector(len) => *len, + Some(Type::Scalar(_)) => 1, + Some(Type::Vector(_, len)) => *len, _ => unreachable!("expected felt or vector, got {:?}", ty), }; trace_segments_arg_vector_len += size; @@ -637,8 +637,8 @@ fn unpack_evaluator_arguments(args: &[Link]) -> Vec> { } else if let Some(parameter) = indexable.as_parameter() { let Parameter { ty, .. } = parameter.deref(); let _size = match ty { - MirType::Felt => 1, - MirType::Vector(len) => *len, + Some(Type::Scalar(_)) => 1, + Some(Type::Vector(_, len)) => *len, _ => unreachable!("expected felt or vector, got {:?}", ty), }; diff --git a/mir/src/passes/mod.rs b/mir/src/passes/mod.rs index cfc1c9126..3c98fcd7d 100644 --- a/mir/src/passes/mod.rs +++ b/mir/src/passes/mod.rs @@ -186,7 +186,7 @@ pub fn duplicate_node( .to_link() .unwrap_or_else(|| panic!("invalid ref_node for parameter {parameter:?}",)); let new_param = - Parameter::create(parameter.position, parameter.ty.clone(), parameter.span()); + Parameter::create(parameter.position, parameter.ty.unwrap(), parameter.span()); if let Some(_root_ref) = owner_ref.as_root() { new_param.as_parameter_mut().unwrap().set_ref_node(owner_ref); @@ -431,7 +431,7 @@ pub fn duplicate_node_or_replace( current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); } else { let new_param = - Parameter::create(parameter.position, parameter.ty.clone(), parameter.span()); + Parameter::create(parameter.position, parameter.ty.unwrap(), parameter.span()); if let Some(_root_ref) = owner_ref.as_root() { new_param.as_parameter_mut().unwrap().set_ref_node(owner_ref.clone()); diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 94591352a..282bb0ee1 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -4,14 +4,15 @@ use std::ops::Deref; use air_parser::{LexicalScope, ast, ast::AccessType, symbols}; use air_pass::Pass; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; +use typing::*; use crate::{ CompileError, ir::{ Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, - MirType, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, - Root, SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Value, Vector, + MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, Root, + SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Type, Value, Vector, }, passes::duplicate_node, }; @@ -189,7 +190,7 @@ impl<'a> MirBuilder<'a> { for binding in trace_segment.bindings.iter() { let name = binding.name.as_ref(); match &binding.ty { - ast::Type::Vector(size) => { + Type::Vector(_, size) => { let mut params_vec = Vec::new(); let mut span = SourceSpan::UNKNOWN; for _ in 0..*size { @@ -203,7 +204,7 @@ impl<'a> MirBuilder<'a> { let vector_node = Vector::create(params_vec, span); self.bindings.insert(name.unwrap(), vector_node.clone()); }, - ast::Type::Felt => { + Type::Scalar(_) => { let param = all_params_flatten_for_trace_segment[i].clone(); i += 1; self.bindings.insert(name.unwrap(), param.clone()); @@ -236,7 +237,7 @@ impl<'a> MirBuilder<'a> { func = func.parameters(param.clone()); } i += 1; - let ret = Parameter::create(i, self.translate_type(&ast_func.return_type), ast_func.span()); + let ret = Parameter::create(i, ast_func.return_type, ast_func.span()); params.push(ret.clone()); let func = func.return_type(ret).build(); @@ -272,25 +273,25 @@ impl<'a> MirBuilder<'a> { &mut self, span: SourceSpan, name: Option<&'a ast::Identifier>, - ty: &ast::Type, + ty: &Type, i: &mut usize, ) -> Result>, CompileError> { match ty { - ast::Type::Felt => { - let param = Parameter::create(*i, MirType::Felt, span); + Type::Scalar(_) => { + let param = Parameter::create(*i, ty!(felt).unwrap(), span); *i += 1; Ok(vec![param]) }, - ast::Type::Vector(size) => { + Type::Vector(_, size) => { let mut params = Vec::new(); for _ in 0..*size { - let param = Parameter::create(*i, MirType::Felt, span); + let param = Parameter::create(*i, ty!(felt[*size]).unwrap(), span); *i += 1; params.push(param); } Ok(params) }, - ast::Type::Matrix(_rows, _cols) => { + Type::Matrix(..) => { let span = if let Some(name) = name { name.span() } else { @@ -310,21 +311,21 @@ impl<'a> MirBuilder<'a> { &mut self, span: SourceSpan, name: Option<&'a ast::Identifier>, - ty: &ast::Type, + ty: &Type, i: &mut usize, ) -> Result, CompileError> { match ty { - ast::Type::Felt => { - let param = Parameter::create(*i, MirType::Felt, span); + Type::Scalar(_) => { + let param = Parameter::create(*i, *ty, span); *i += 1; Ok(param) }, - ast::Type::Vector(size) => { - let param = Parameter::create(*i, MirType::Vector(*size), span); + Type::Vector(..) => { + let param = Parameter::create(*i, *ty, span); *i += 1; Ok(param) }, - ast::Type::Matrix(_rows, _cols) => { + Type::Matrix(..) => { let span = if let Some(name) = name { name.span() } else { @@ -364,14 +365,6 @@ impl<'a> MirBuilder<'a> { Ok(func) } - fn translate_type(&mut self, ty: &ast::Type) -> MirType { - match ty { - ast::Type::Felt => MirType::Felt, - ast::Type::Vector(size) => MirType::Vector(*size), - ast::Type::Matrix(rows, cols) => MirType::Matrix(*rows, *cols), - } - } - fn translate_statement(&mut self, stmt: &'a ast::Statement) -> Result, CompileError> { match stmt { ast::Statement::Let(let_stmt) => self.translate_let(let_stmt), @@ -451,7 +444,8 @@ impl<'a> MirBuilder<'a> { self.bindings.enter(); for (index, binding) in list_comp.bindings.iter().enumerate() { - let binding_node = Parameter::create(index, ast::Type::Felt.into(), binding.span()); + // TODO: extract the type from the bound variable + let binding_node = Parameter::create(index, ty!(felt).unwrap(), binding.span()); params.push(binding_node.clone()); self.bindings.insert(binding, binding_node); } @@ -941,7 +935,8 @@ impl<'a> MirBuilder<'a> { self.bindings.enter(); let mut params = Vec::new(); for (index, binding) in list_comp.bindings.iter().enumerate() { - let binding_node = Parameter::create(index, ast::Type::Felt.into(), binding.span()); + // TODO: extract the type from the bound variable + let binding_node = Parameter::create(index, ty!(felt).unwrap(), binding.span()); params.push(binding_node.clone()); self.bindings.insert(binding, binding_node); } @@ -1165,8 +1160,8 @@ impl<'a> MirBuilder<'a> { // // In that case, replacing the default type (Felt) with the one from the access if let Some(mut param) = let_bound_access_expr.as_parameter_mut() { - if let Some(access_ty) = &access.ty { - param.ty = self.translate_type(access_ty); + if let Some(_) = &access.ty { + param.ty = access.ty } } let accessor: Link = Accessor::create( diff --git a/mir/src/passes/unrolling.rs b/mir/src/passes/unrolling.rs index 3674c2d32..283d6092f 100644 --- a/mir/src/passes/unrolling.rs +++ b/mir/src/passes/unrolling.rs @@ -815,9 +815,12 @@ impl UnrollingFirstPass<'_> { AccessType::Matrix(..) => 1, }, Op::Parameter(parameter) => match parameter.ty { - MirType::Felt => 1, - MirType::Vector(l) => l, - MirType::Matrix(l, _) => l, + Some(Type::Scalar(_)) => 1, + Some(Type::Vector(_, l)) => l, + Some(Type::Matrix(_, l, _)) => l, + _ => { + unreachable!("Parameter should have a type"); // Raise diag + }, }, _ => 1, } @@ -862,8 +865,11 @@ impl UnrollingFirstPass<'_> { let mut new_vec = vec![]; for i in 0..iterator_expected_len { - let new_node = - Parameter::create(i, MirType::Felt, for_node.as_for().unwrap().deref().span()); + let new_node = Parameter::create( + i, + ty!(felt).unwrap(), + for_node.as_for().unwrap().deref().span(), + ); new_vec.push(new_node.clone()); let iterators_i = iterators @@ -1077,48 +1083,47 @@ impl Visitor for UnrollingSecondPass<'_> { let new_node = self.nodes_to_replace.get(&body.get_ptr()).unwrap().1.clone(); // If there is a selector, we need to enforce it on the body - let new_node_with_selector_if_needed = if let Some(selector) = - self.for_inlining_context.clone().unwrap().selector - { - if let Op::Vector(new_node_vector) = new_node.borrow().deref() { - let new_node_vec = new_node_vector.children().borrow().deref().clone(); - let mut new_vec = vec![]; - for new_node_child in new_node_vec.into_iter() { + let new_node_with_selector_if_needed = + if let Some(selector) = self.for_inlining_context.clone().unwrap().selector { + if let Op::Vector(new_node_vector) = new_node.borrow().deref() { + let new_node_vec = new_node_vector.children().borrow().deref().clone(); + let mut new_vec = vec![]; + for new_node_child in new_node_vec.into_iter() { + let zero_node = Value::create(SpannedMirValue { + span: Default::default(), + value: MirValue::Constant(ConstantValue::Felt(0)), + }); + // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> + // Enf(Sub(lhs, rhs) == 0), but it introduces an + // unnecessary zero node + let new_node_child_with_selector = Sub::create( + Mul::create( + duplicate_node(selector.clone(), &mut HashMap::new()), + new_node_child, + root.span(), + ), + zero_node, + root.span(), + ); + new_vec.push(new_node_child_with_selector); + } + Vector::create(new_vec, root.span()) + } else { let zero_node = Value::create(SpannedMirValue { span: Default::default(), value: MirValue::Constant(ConstantValue::Felt(0)), }); - // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> - // Enf(Sub(lhs, rhs) == 0), but it introduces an - // unnecessary zero node - let new_node_child_with_selector = Sub::create( - Mul::create( - duplicate_node(selector.clone(), &mut HashMap::new()), - new_node_child, - root.span(), - ), + // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> Enf(Sub(lhs, + // rhs) == 0), but it introduces an unnecessary zero node + Sub::create( + Mul::create(selector, new_node, root.span()), zero_node, root.span(), - ); - new_vec.push(new_node_child_with_selector); + ) } - Vector::create(new_vec, root.span()) } else { - let zero_node = Value::create(SpannedMirValue { - span: Default::default(), - value: MirValue::Constant(ConstantValue::Felt(0)), - }); - // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> Enf(Sub(lhs, - // rhs) == 0), but it introduces an unnecessary zero node - Sub::create( - Mul::create(selector, new_node, root.span()), - zero_node, - root.span(), - ) - } - } else { - new_node - }; + new_node + }; root.as_op().unwrap().set(&new_node_with_selector_if_needed); diff --git a/parser/Cargo.toml b/parser/Cargo.toml index c50738c40..29b8029db 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -16,6 +16,7 @@ lalrpop = { version = "0.20", default-features = false } [dependencies] air-pass = { package = "air-pass", path = "../pass", version = "0.5" } +typing = { package = "typing", path = "../typing", version = "0.1" } either = "1.12" lalrpop-util = "0.20" lazy_static = "1.4" diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index 3939c7860..e0d824698 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -144,11 +144,15 @@ impl Constant { pub const fn new(span: SourceSpan, name: Identifier, value: ConstantExpr) -> Self { Self { span, name, value } } - +} +impl Typing for Constant { /// Gets the type of the value associated with this constant - pub fn ty(&self) -> Type { + fn ty(&self) -> Option { self.value.ty() } + fn kind(&self) -> Option { + self.value.kind() + } } impl Eq for Constant {} impl PartialEq for Constant { @@ -168,25 +172,24 @@ pub enum ConstantExpr { Vector(Vec), Matrix(Vec>), } -impl ConstantExpr { +impl Typing for ConstantExpr { /// Gets the type of this expression - pub fn ty(&self) -> Type { + fn ty(&self) -> Option { match self { - Self::Scalar(_) => Type::Felt, - Self::Vector(elems) => Type::Vector(elems.len()), + Self::Scalar(_) => ty!(uint), + Self::Vector(elems) => ty!(uint[elems.len()]), Self::Matrix(rows) => { let num_rows = rows.len(); let num_cols = rows.first().unwrap().len(); - Type::Matrix(num_rows, num_cols) + ty!(uint[num_rows, num_cols]) }, } } - - /// Returns true if this expression is of aggregate type - pub fn is_aggregate(&self) -> bool { - matches!(self, Self::Vector(_) | Self::Matrix(_)) + fn kind(&self) -> Option { + self.ty().kind() } } + impl fmt::Display for ConstantExpr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { @@ -255,14 +258,21 @@ impl Export<'_> { Self::Evaluator(item) => item.name, } } - +} +impl Typing for Export<'_> { /// Returns the type of the value associated with this export /// /// NOTE: Evaluator functions have no return value, so they have no type associated. /// For this reason, this function returns `Option` rather than `Type`. - pub fn ty(&self) -> Option { + fn ty(&self) -> Option { match self { - Self::Constant(item) => Some(item.ty()), + Self::Constant(item) => item.ty(), + Self::Evaluator(_) => None, + } + } + fn kind(&self) -> Option { + match self { + Self::Constant(item) => item.kind(), Self::Evaluator(_) => None, } } @@ -373,16 +383,19 @@ pub struct EvaluatorFunction { pub name: Identifier, pub params: Vec, pub body: Vec, + pub fn_ty: FunctionType, } impl EvaluatorFunction { /// Creates a new function. - pub const fn new( + pub fn new( span: SourceSpan, name: Identifier, params: Vec, body: Vec, ) -> Self { - Self { span, name, params, body } + let p = params.iter().map(|ty| ty.ty()).collect::>(); + let fn_ty = FunctionType::Evaluator(p); + Self { span, name, params, body, fn_ty } } } impl Eq for EvaluatorFunction {} @@ -405,17 +418,28 @@ pub struct Function { pub params: Vec<(Identifier, Type)>, pub return_type: Type, pub body: Vec, + pub fn_ty: FunctionType, } impl Function { /// Creates a new function. - pub const fn new( + pub fn new( span: SourceSpan, name: Identifier, params: Vec<(Identifier, Type)>, return_type: Type, body: Vec, ) -> Self { - Self { span, name, params, return_type, body } + let p = params.iter().map(|(_, ty)| ty.ty()).collect::>(); + let r = return_type.ty(); + let fn_ty = FunctionType::Function(p, r); + Self { + span, + name, + params, + return_type, + body, + fn_ty, + } } pub fn param_types(&self) -> Vec { diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index 49b0f285a..c88b36593 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -318,31 +318,27 @@ impl Expr { _ => false, } } - +} +impl Typing for Expr { /// Returns the resolved type of this expression, if known - pub fn ty(&self) -> Option { + fn ty(&self) -> Option { match self { - Self::Const(constant) => Some(constant.ty()), + Self::Const(constant) => constant.ty(), Self::Range(range) => range.ty(), - Self::Vector(vector) => match vector.first().and_then(|e| e.ty()) { - Some(Type::Felt) => Some(Type::Vector(vector.len())), - Some(Type::Vector(n)) => Some(Type::Matrix(vector.len(), n)), - Some(_) => None, - None => Some(Type::Vector(0)), - }, - Self::Matrix(matrix) => { - let rows = matrix.len(); - let cols = matrix[0].len(); - Some(Type::Matrix(rows, cols)) - }, + Self::Vector(vector) => vector.ty(), + Self::Matrix(matrix) => matrix.ty(), Self::SymbolAccess(access) => access.ty, - Self::Binary(_) => Some(Type::Felt), + Self::Binary(bin_expr) => bin_expr.ty(), Self::Call(call) => call.ty, Self::ListComprehension(lc) => lc.ty, Self::Let(let_expr) => let_expr.ty(), - Self::BusOperation(_) | Self::Null(_) | Self::Unconstrained(_) => Some(Type::Felt), + Self::BusOperation(_) | Self::Null(_) | Self::Unconstrained(_) => ty!(felt), } } + + fn kind(&self) -> Option { + self.ty().kind() + } } impl fmt::Debug for Expr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { @@ -529,30 +525,30 @@ impl ScalarExpr { _ => false, } } - +} +impl Typing for ScalarExpr { /// Returns the resolved type of this expression, if known. /// /// Returns `Ok(Some)` if the type could be resolved without conflict. /// Returns `Ok(None)` if type information was missing. /// Returns `Err` if the type could not be resolved due to a conflict, /// with a span covering the source of the conflict. - pub fn ty(&self) -> Result, SourceSpan> { + fn infer_ty(&self) -> Result, TypeError> { match self { - Self::Const(_) => Ok(Some(Type::Felt)), + Self::Const(_) => Ok(ty!(uint)), Self::SymbolAccess(sym) => Ok(sym.ty), Self::BoundedSymbolAccess(sym) => Ok(sym.column.ty), - Self::Binary(expr) => match (expr.lhs.ty()?, expr.rhs.ty()?) { - (None, _) | (_, None) => Ok(None), - (Some(lty), Some(rty)) if lty == rty => Ok(Some(lty)), - _ => Err(expr.span()), - }, + Self::Binary(expr) => expr.infer_ty(), Self::Call(expr) => Ok(expr.ty), Self::Let(expr) => Ok(expr.ty()), Self::BusOperation(_) | ScalarExpr::Null(_) | ScalarExpr::Unconstrained(_) => { - Ok(Some(Type::Felt)) + Ok(ty!(felt)) }, } } + fn ty(&self) -> Option { + self.infer_ty().ok()? + } } impl TryFrom for ScalarExpr { type Error = InvalidExprError; @@ -706,7 +702,7 @@ impl RangeExpr { pub fn ty(&self) -> Option { match (&self.start, &self.end) { (RangeBound::Const(start), RangeBound::Const(end)) => { - Some(Type::Vector(end.item.abs_diff(start.item))) + ty!(uint[end.item.abs_diff(start.item)]) }, _ => None, } @@ -776,14 +772,31 @@ pub struct BinaryExpr { pub op: BinaryOp, pub lhs: Box, pub rhs: Box, + pub bin_ty: BinType, } impl BinaryExpr { pub fn new(span: SourceSpan, op: BinaryOp, lhs: ScalarExpr, rhs: ScalarExpr) -> Self { + debug_assert!( + lhs.ty().is_none() || rhs.ty().is_none() || (lhs.is_scalar() && rhs.is_scalar()), + "binary expression operands must both be scalars, got: {} and {}", + lhs.show_ty(), + rhs.show_ty(), + ); + let l_ty = lhs.scalar_ty(); + let r_ty = rhs.scalar_ty(); + let bin_ty = match op { + BinaryOp::Eq => bty!(l_ty = r_ty), + BinaryOp::Add => bty!(l_ty + r_ty), + BinaryOp::Sub => bty!(l_ty - r_ty), + BinaryOp::Mul => bty!(l_ty * r_ty), + BinaryOp::Exp => bty!(l_ty ^ r_ty), + }; Self { span, op, lhs: Box::new(lhs), rhs: Box::new(rhs), + bin_ty, } } @@ -814,6 +827,24 @@ impl fmt::Display for BinaryExpr { write!(f, "{} {} {}", &self.lhs, &self.op, &self.rhs) } } +impl Typing for BinaryExpr { + fn scalar_ty(&self) -> Option { + self.bin_ty.scalar_ty() + } + fn ty(&self) -> Option { + self.bin_ty.ty() + } +} +impl ScalarTypeMut for BinaryExpr { + fn scalar_ty_mut(&mut self) -> &mut Option { + self.bin_ty.scalar_ty_mut() + } +} +impl TypeMut for BinaryExpr { + fn ty_mut(&mut self) -> &mut Option { + self.bin_ty.ty_mut() + } +} #[derive(Debug, Copy, Clone, PartialEq, Eq)] pub enum BinaryOp { @@ -858,6 +889,12 @@ impl fmt::Display for Boundary { } } +pub trait Access { + type Accessed; + /// Return a new [Type] representing the type of the value produced by the given [AccessType] + fn access(&self, access_type: AccessType) -> Result; +} + /// Represents the way an identifier is accessed/referenced in the source. #[derive(Hash, Debug, Clone, Eq, PartialEq, Default)] pub enum AccessType { @@ -973,51 +1010,55 @@ impl SymbolAccess { match access_type { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { - Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), - Type::Vector(_) => Ok(Self { + Type::Scalar(_) => Err(InvalidAccessError::IndexIntoScalar), + Type::Vector(_, len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Index(idx), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(_, rows, _) if idx >= rows => { + Err(InvalidAccessError::IndexOutOfBounds) + }, + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Index(idx), - ty: Some(Type::Vector(cols)), + ty: ty!(sty[cols]), ..self.clone() }), }, AccessType::Slice(range) => { let slice_range = range.to_slice_range(); let rlen = slice_range.end - slice_range.start; + // TODO: check if this is valid: + // let rlen = slice_range.end.abs_diff(slice_range.start); match ty { - Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if slice_range.end > len => { + Type::Scalar(_) => Err(InvalidAccessError::IndexIntoScalar), + Type::Vector(_, len) if slice_range.end > len => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Vector(_) => Ok(Self { + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Slice(range), - ty: Some(Type::Vector(rlen)), + ty: ty!(sty[rlen]), ..self.clone() }), - Type::Matrix(rows, _) if slice_range.end > rows => { + Type::Matrix(_, rows, _) if slice_range.end > rows => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Slice(range), - ty: Some(Type::Matrix(rlen, cols)), + ty: ty!(sty[rlen, cols]), ..self.clone() }), } }, AccessType::Matrix(row, col) => match ty { - Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar), - Type::Matrix(rows, cols) if row >= rows || col >= cols => { + Type::Scalar(_) | Type::Vector(..) => Err(InvalidAccessError::IndexIntoScalar), + Type::Matrix(_, rows, cols) if row >= rows || col >= cols => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Matrix(..) => Ok(Self { + Type::Matrix(sty, ..) => Ok(Self { access_type: AccessType::Matrix(row, col), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), }, @@ -1033,17 +1074,19 @@ impl SymbolAccess { match access_type { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { - Type::Felt => unreachable!(), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), - Type::Vector(_) => Ok(Self { + Type::Scalar(_) => unreachable!(), + Type::Vector(_, len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Index(base_range.start + idx), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(_, rows, _) if idx >= rows => { + Err(InvalidAccessError::IndexOutOfBounds) + }, + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Index(base_range.start + idx), - ty: Some(Type::Vector(cols)), + ty: ty!(sty[cols]), ..self.clone() }), }, @@ -1059,33 +1102,33 @@ impl SymbolAccess { end: RangeBound::Const(Span::new(range.end.span(), end)), }; match ty { - Type::Felt => unreachable!(), - Type::Vector(_) if slice_range.end > blen => { + Type::Scalar(_) => unreachable!(), + Type::Vector(..) if slice_range.end > blen => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Vector(_) => Ok(Self { + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Slice(shifted), - ty: Some(Type::Vector(rlen)), + ty: ty!(sty[rlen]), ..self.clone() }), - Type::Matrix(rows, _) if slice_range.end > rows => { + Type::Matrix(_, rows, _) if slice_range.end > rows => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Slice(shifted), - ty: Some(Type::Matrix(rlen, cols)), + ty: ty!(sty[rlen, cols]), ..self.clone() }), } }, AccessType::Matrix(row, col) => match ty { - Type::Felt | Type::Vector(_) => Err(InvalidAccessError::IndexIntoScalar), - Type::Matrix(rows, cols) if row >= rows || col >= cols => { + Type::Scalar(_) | Type::Vector(..) => Err(InvalidAccessError::IndexIntoScalar), + Type::Matrix(_, rows, cols) if row >= rows || col >= cols => { Err(InvalidAccessError::IndexOutOfBounds) }, - Type::Matrix(..) => Ok(Self { + Type::Matrix(sty, ..) => Ok(Self { access_type: AccessType::Matrix(row, col), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), }, @@ -1101,17 +1144,19 @@ impl SymbolAccess { match access_type { AccessType::Default => Ok(self.clone()), AccessType::Index(idx) => match ty { - Type::Felt => Err(InvalidAccessError::IndexIntoScalar), - Type::Vector(len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), - Type::Vector(_) => Ok(Self { + Type::Scalar(_) => Err(InvalidAccessError::IndexIntoScalar), + Type::Vector(_, len) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), + Type::Vector(sty, _) => Ok(Self { access_type: AccessType::Matrix(base_idx, idx), - ty: Some(Type::Felt), + ty: ty!(sty), ..self.clone() }), - Type::Matrix(rows, _) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), - Type::Matrix(_, cols) => Ok(Self { + Type::Matrix(_, rows, _) if idx >= rows => { + Err(InvalidAccessError::IndexOutOfBounds) + }, + Type::Matrix(sty, _, cols) => Ok(Self { access_type: AccessType::Matrix(base_idx, idx), - ty: Some(Type::Vector(cols)), + ty: ty!(sty[cols]), ..self.clone() }), }, @@ -1383,13 +1428,15 @@ impl Call { /// Constructs a function call for the `sum` reducer/fold #[inline] pub fn sum(span: SourceSpan, args: Vec) -> Self { - Self::new_builtin(span, "sum", args, Type::Felt) + // TODO: adapt to the new type system and use BinType instead of this type + Self::new_builtin(span, "sum", args, ty!(felt).unwrap()) } /// Constructs a function call for the `prod` reducer/fold #[inline] pub fn prod(span: SourceSpan, args: Vec) -> Self { - Self::new_builtin(span, "prod", args, Type::Felt) + // TODO: adapt to the new type system and use BinType instead of this type + Self::new_builtin(span, "prod", args, ty!(felt).unwrap()) } fn new_builtin(span: SourceSpan, name: &str, args: Vec, ty: Type) -> Self { diff --git a/parser/src/ast/trace.rs b/parser/src/ast/trace.rs index 3e1247bb4..333b5dcd6 100644 --- a/parser/src/ast/trace.rs +++ b/parser/src/ast/trace.rs @@ -1,6 +1,7 @@ use std::fmt; use miden_diagnostics::{SourceSpan, Spanned}; +use typing::{FunctionType, Kind, Typing, tty, ty}; use super::*; @@ -27,6 +28,7 @@ pub struct TraceSegment { /// A vector of `size` elements which tracks for every column whether a /// constraint has been applied to that column, and on what boundaries. pub boundary_constrained: Vec>, + pub fn_ty: Option, } impl TraceSegment { /// Constructs a new [TraceSegment] given a span, segment id, name, and a vector of (Identifier, @@ -42,16 +44,21 @@ impl TraceSegment { for binding in raw_bindings.into_iter() { let (name, size) = binding.item; let ty = match size { - 1 => Type::Felt, - n => Type::Vector(n), - }; + 1 => tty!(name), + n => tty!(name[n]), + } + .unwrap_or_else(|| { + unreachable!( + "Trace segment binding types should always be known, but got None for {name} with size {size}" + ) + }); bindings.push(TraceBinding::new(binding.span(), name, id, offset, size, ty)); offset += size; } // The size of the segment is the sum of the sizes of all the bindings let size = offset; - Self { + let mut res = Self { span, id, name, @@ -61,7 +68,13 @@ impl TraceSegment { Span::new(SourceSpan::UNKNOWN, ColumnBoundaryFlags::EMPTY); size ], - } + fn_ty: None, + }; + res.fn_ty = match res.kind() { + Some(Kind::Callable(fty)) => Some(fty), + _ => None, + }; + res } /// Returns true if `column` is constrained on `boundary` @@ -95,6 +108,16 @@ impl TraceSegment { self.size == 0 } } +impl Typing for TraceSegment { + fn ty(&self) -> Option { + None + } + fn kind(&self) -> Option { + Some(Kind::Callable(FunctionType::Evaluator( + self.bindings.iter().map(|b| b.ty()).collect(), + ))) + } +} impl fmt::Debug for TraceSegment { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { f.debug_struct("TraceSegment") @@ -227,20 +250,22 @@ impl TraceBinding { ty, } } +} +impl Typing for TraceBinding { /// Returns a [Type] that describes what type of value this binding represents - #[inline] - pub fn ty(&self) -> Type { - self.ty + fn ty(&self) -> Option { + Some(self.ty) } - - #[inline] - pub fn is_scalar(&self) -> bool { - self.ty.is_scalar() + fn kind(&self) -> Option { + Some(Kind::Value(self.ty())) } +} +impl Access for TraceBinding { + type Accessed = Self; /// Derive a new [TraceBinding] derived from the current one given an [AccessType] - pub fn access(&self, access_type: AccessType) -> Result { + fn access(&self, access_type: AccessType) -> Result { match access_type { AccessType::Default => Ok(*self), AccessType::Slice(_) if self.is_scalar() => Err(InvalidAccessError::SliceOfScalar), @@ -254,7 +279,7 @@ impl TraceBinding { Ok(Self { offset, size, - ty: Type::Vector(size), + ty: ty!(felt[size]).unwrap(), ..*self }) } @@ -263,7 +288,12 @@ impl TraceBinding { AccessType::Index(idx) if idx >= self.size => Err(InvalidAccessError::IndexOutOfBounds), AccessType::Index(idx) => { let offset = self.offset + idx; - Ok(Self { offset, size: 1, ty: Type::Felt, ..*self }) + Ok(Self { + offset, + size: 1, + ty: ty!(felt).unwrap(), + ..*self + }) }, AccessType::Matrix(..) => Err(InvalidAccessError::IndexIntoScalar), } diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index 3155180ec..691ee8c0a 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs @@ -1,106 +1,45 @@ use super::*; +pub use typing::*; +pub use typing::{bty, fty, kind, sty, tty, ty, tys}; -/// The types of values which can be represented in an AirScript program -#[derive(Hash, Debug, Copy, Clone, PartialEq, Eq)] -pub enum Type { - /// A field element - Felt, - /// A vector of N integers - Vector(usize), - /// A matrix of N rows and M columns - Matrix(usize, usize), -} -impl Type { - /// Returns true if this type is an aggregate - #[inline] - pub fn is_aggregate(&self) -> bool { - match self { - Self::Felt => false, - Self::Vector(_) | Self::Matrix(..) => true, - } - } - - /// Returns true if this type is a scalar - #[inline] - pub fn is_scalar(&self) -> bool { - matches!(self, Self::Felt) - } - - /// Returns true if this type is a valid iterable in a comprehension - #[inline] - pub fn is_iterable(&self) -> bool { - self.is_vector() - } - - /// Returns true if this type is a vector - #[inline] - pub fn is_vector(&self) -> bool { - matches!(self, Self::Vector(_)) - } - +impl Access for Type { + type Accessed = Self; /// Return a new [Type] representing the type of the value produced by the given [AccessType] - pub fn access(&self, access_type: AccessType) -> Result { + fn access(&self, access_type: AccessType) -> Result { match *self { ty if access_type == AccessType::Default => Ok(ty), - Self::Felt => Err(InvalidAccessError::IndexIntoScalar), - Self::Vector(len) => match access_type { + Self::Scalar(_) => Err(InvalidAccessError::IndexIntoScalar), + Self::Vector(sty, len) => match access_type { AccessType::Slice(range) => { let slice_range = range.to_slice_range(); if slice_range.end > len { Err(InvalidAccessError::IndexOutOfBounds) } else { - Ok(Self::Vector(slice_range.len())) + Ok(Self::Vector(sty, slice_range.len())) } }, AccessType::Index(idx) if idx >= len => Err(InvalidAccessError::IndexOutOfBounds), - AccessType::Index(_) => Ok(Self::Felt), + AccessType::Index(_) => Ok(Self::Scalar(sty)), AccessType::Matrix(..) => Err(InvalidAccessError::IndexIntoScalar), _ => unreachable!(), }, - Self::Matrix(rows, cols) => match access_type { + Self::Matrix(sty, rows, cols) => match access_type { AccessType::Slice(range) => { let slice_range = range.to_slice_range(); if slice_range.end > rows { Err(InvalidAccessError::IndexOutOfBounds) } else { - Ok(Self::Matrix(slice_range.len(), cols)) + Ok(Self::Matrix(sty, slice_range.len(), cols)) } }, AccessType::Index(idx) if idx >= rows => Err(InvalidAccessError::IndexOutOfBounds), - AccessType::Index(_) => Ok(Self::Vector(cols)), + AccessType::Index(_) => Ok(Self::Vector(sty, cols)), AccessType::Matrix(row, col) if row >= rows || col >= cols => { Err(InvalidAccessError::IndexOutOfBounds) }, - AccessType::Matrix(..) => Ok(Self::Felt), + AccessType::Matrix(..) => Ok(Self::Scalar(sty)), _ => unreachable!(), }, } } } -impl fmt::Display for Type { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - Self::Felt => f.write_str("felt"), - Self::Vector(n) => write!(f, "felt[{n}]"), - Self::Matrix(rows, cols) => write!(f, "felt[{rows}, {cols}]"), - } - } -} - -/// Represents the type signature of a function -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum FunctionType { - /// An evaluator function, which has no results, and has - /// a complex type signature due to the nature of trace bindings - Evaluator(Vec), - /// A standard function with one or more inputs, and a result - Function(Vec, Type), -} -impl FunctionType { - pub fn result(&self) -> Option { - match self { - Self::Evaluator(_) => None, - Self::Function(_, result) => Some(*result), - } - } -} diff --git a/parser/src/lexer/mod.rs b/parser/src/lexer/mod.rs index f354cf6f4..5683b30b9 100644 --- a/parser/src/lexer/mod.rs +++ b/parser/src/lexer/mod.rs @@ -153,9 +153,14 @@ pub enum Token { Match, Case, When, - Felt, With, + // SCALAR TYPES + // -------------------------------------------------------------------------------------------- + Felt, + Bool, + UInt, + // PUNCTUATION // -------------------------------------------------------------------------------------------- Quote, @@ -196,6 +201,8 @@ impl Token { "ev" => Self::Ev, "fn" => Self::Fn, "felt" => Self::Felt, + "bool" => Self::Bool, + "uint" => Self::UInt, "buses" => Self::Buses, "multiset" => Self::Multiset, "logup" => Self::Logup, @@ -275,6 +282,8 @@ impl fmt::Display for Token { Self::Ev => write!(f, "ev"), Self::Fn => write!(f, "fn"), Self::Felt => write!(f, "felt"), + Self::Bool => write!(f, "bool"), + Self::UInt => write!(f, "uint"), Self::Buses => write!(f, "buses"), Self::Multiset => write!(f, "multiset"), Self::Logup => write!(f, "logup"), diff --git a/parser/src/parser/grammar.lalrpop b/parser/src/parser/grammar.lalrpop index 7f9ca3a44..e0677450d 100644 --- a/parser/src/parser/grammar.lalrpop +++ b/parser/src/parser/grammar.lalrpop @@ -241,9 +241,15 @@ FunctionBinding: (Identifier, Type) = { } FunctionBindingType: Type = { - "felt" => Type::Felt, - "felt" => Type::Vector(size as usize), - "felt" "[" "," "]" => Type::Matrix(row_size as usize, col_size as usize), + => Type::Scalar(sty), + => Type::Vector(sty, size as usize), + "[" "," "]" => Type::Matrix(sty, row_size as usize, col_size as usize), +} + +ScalarType: Option = { + "felt" => Some(ScalarType::Felt), + "uint" => Some(ScalarType::UInt), + "bool" => Some(ScalarType::Bool), } FunctionBody: Vec = { @@ -677,6 +683,8 @@ extern { "when" => Token::When, "with" => Token::With, "felt" => Token::Felt, + "uint" => Token::UInt, + "bool" => Token::Bool, "'" => Token::Quote, "=" => Token::Equal, "+" => Token::Plus, diff --git a/parser/src/parser/tests/constant_propagation.rs b/parser/src/parser/tests/constant_propagation.rs index 8096d0542..5c60f31be 100644 --- a/parser/src/parser/tests/constant_propagation.rs +++ b/parser/src/parser/tests/constant_propagation.rs @@ -76,22 +76,22 @@ fn test_constant_propagation() { // enf a.first = 1 expected .boundary_constraints - .push(enforce!(eq!(bounded_access!(a, Boundary::First, Type::Felt), int!(1)))); + .push(enforce!(eq!(bounded_access!(a, Boundary::First, ty!(felt).unwrap()), int!(1)))); // When constant propagation is done, the integrity constraints should look like: // enf test_constraint(b) // enf a + 4 = c + 5 expected .integrity_constraints - .push(enforce!(call!(lib::test_constraint(expr!(access!(b, Type::Vector(2))))))); + .push(enforce!(call!(lib::test_constraint(expr!(access!(b, ty!(felt[2]).unwrap())))))); expected.integrity_constraints.push(enforce!(eq!( - add!(access!(a, Type::Felt), int!(4)), - add!(access!(c, Type::Felt), int!(5)) + add!(access!(a, ty!(felt).unwrap()), int!(4)), + add!(access!(c, ty!(felt).unwrap()), int!(5)) ))); // The test_constraint function should look like: // enf b0 + 2 = b1 + 4 let body = vec![enforce!(eq!( - add!(access!(b0, Type::Felt), int!(2)), - add!(access!(b1, Type::Felt), int!(4)) + add!(access!(b0, ty!(felt).unwrap()), int!(2)), + add!(access!(b1, ty!(felt).unwrap()), int!(4)) ))]; expected.evaluators.insert( function_ident!(lib, test_constraint), diff --git a/parser/src/parser/tests/functions.rs b/parser/src/parser/tests/functions.rs index 2e5efe9ec..76fdcd5ca 100644 --- a/parser/src/parser/tests/functions.rs +++ b/parser/src/parser/tests/functions.rs @@ -21,8 +21,8 @@ fn fn_def_with_scalars() { Function::new( SourceSpan::UNKNOWN, function_ident!(fn_with_scalars), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Felt)], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(add!(access!(a), access!(b))))], ), ); @@ -44,8 +44,8 @@ fn fn_def_with_vectors() { Function::new( SourceSpan::UNKNOWN, function_ident!(fn_with_vectors), - vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], - Type::Vector(12), + vec![(ident!(a), ty!(felt[12]).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt[12]).unwrap(), vec![return_!(expr!(lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!(access!(x), access!(y)))))], ), @@ -85,8 +85,8 @@ fn fn_use_scalars_and_vectors() { Function::new( SourceSpan::UNKNOWN, function_ident!(fn_with_scalars_and_vectors), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(call!(sum(expr!( lc!(((x, expr!(access!(b)))) => add!(access!(a), access!(x))) )))))], @@ -150,8 +150,8 @@ fn fn_call_in_fn() { Function::new( SourceSpan::UNKNOWN, function_ident!(fold_vec), - vec![(ident!(a), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(call!(sum(expr!(lc!(((x, expr!(access!(a)))) => access!(x)))))))], ), ); @@ -161,8 +161,8 @@ fn fn_call_in_fn() { Function::new( SourceSpan::UNKNOWN, function_ident!(fold_scalar_and_vec), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(add!(access!(a), call!(fold_vec(expr!(access!(b)))))))], ), ); @@ -230,8 +230,8 @@ fn fn_call_in_ev() { Function::new( SourceSpan::UNKNOWN, function_ident!(fold_vec), - vec![(ident!(a), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(call!(sum(expr!(lc!(((x, expr!(access!(a)))) => access!(x)))))))], ), ); @@ -241,8 +241,8 @@ fn fn_call_in_ev() { Function::new( SourceSpan::UNKNOWN, function_ident!(fold_scalar_and_vec), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(add!(access!(a), call!(fold_vec(expr!(access!(b)))))))], ), ); @@ -313,8 +313,8 @@ fn fn_as_lc_iterables() { Function::new( SourceSpan::UNKNOWN, function_ident!(operation), - vec![(ident!(a), Type::Felt), (ident!(b), Type::Felt)], - Type::Felt, + vec![(ident!(a), ty!(felt).unwrap()), (ident!(b), ty!(felt).unwrap())], + ty!(felt).unwrap(), vec![let_!(x = expr!(add!(exp!(access!(a), access!(b)), int!(1))) => return_!(expr!(exp!(access!(b), access!(x)))))], ), @@ -382,8 +382,8 @@ fn fn_call_in_binary_ops() { Function::new( SourceSpan::UNKNOWN, function_ident!(operation), - vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], - Type::Felt, + vec![(ident!(a), ty!(felt[12]).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt).unwrap(), vec![return_!(expr!(call!(sum(expr!( lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!( access!(x), @@ -456,8 +456,8 @@ fn fn_call_in_vector_def() { Function::new( SourceSpan::UNKNOWN, function_ident!(operation), - vec![(ident!(a), Type::Vector(12)), (ident!(b), Type::Vector(12))], - Type::Vector(12), + vec![(ident!(a), ty!(felt[12]).unwrap()), (ident!(b), ty!(felt[12]).unwrap())], + ty!(felt[12]).unwrap(), vec![return_!(expr!(lc!(((x, expr!(access!(a))), (y, expr!(access!(b)))) => add!( access!(x), access!(y) diff --git a/parser/src/parser/tests/mod.rs b/parser/src/parser/tests/mod.rs index bd4116931..a8a4c6b8d 100644 --- a/parser/src/parser/tests/mod.rs +++ b/parser/src/parser/tests/mod.rs @@ -696,15 +696,16 @@ fn full_air_file() { // enf clk' = clk + 1 // } expected.integrity_constraints.push(enforce!(eq!( - access!(clk, 1, Type::Felt), - add!(access!(clk, Type::Felt), int!(1)) + access!(clk, 1, ty!(felt).unwrap()), + add!(access!(clk, ty!(felt).unwrap()), int!(1)) ))); // boundary_constraints { // enf clk.first = 0 // } - expected - .boundary_constraints - .push(enforce!(eq!(bounded_access!(clk, Boundary::First, Type::Felt), int!(0)))); + expected.boundary_constraints.push(enforce!(eq!( + bounded_access!(clk, Boundary::First, ty!(felt).unwrap()), + int!(0) + ))); ParseTest::new().expect_program_ast_from_file("src/parser/tests/input/system.air", expected); } diff --git a/parser/src/parser/tests/modules.rs b/parser/src/parser/tests/modules.rs index 66c52480d..57d903cc0 100644 --- a/parser/src/parser/tests/modules.rs +++ b/parser/src/parser/tests/modules.rs @@ -68,10 +68,10 @@ fn modules_integration_test() { vec![trace_segment!(0, "%0", [(clk, 1)])], vec![enforce_if!(match_arm!( eq!( - access!(clk, 1, Type::Felt), - add!(access!(clk, Type::Felt), access!(bar, k0, Type::Felt)) + access!(clk, 1, ty!(felt).unwrap()), + add!(access!(clk, ty!(felt).unwrap()), access!(bar, k0, ty!(felt).unwrap())) ), - access!(bar, k0, Type::Felt) + access!(bar, k0, ty!(felt).unwrap()) ))], ), ); @@ -85,8 +85,11 @@ fn modules_integration_test() { ident!(foo_constraint), vec![trace_segment!(0, "%0", [(clk, 1)])], vec![enforce_if!(match_arm!( - eq!(access!(clk, 1, Type::Felt), add!(access!(clk, Type::Felt), int!(1))), - access!(foo, k0, Type::Felt) + eq!( + access!(clk, 1, ty!(felt).unwrap()), + add!(access!(clk, ty!(felt).unwrap()), int!(1)) + ), + access!(foo, k0, ty!(felt).unwrap()) ))], ), ); @@ -95,13 +98,14 @@ fn modules_integration_test() { .insert(ident!(inputs), PublicInput::new_vector(SourceSpan::UNKNOWN, ident!(inputs), 2)); expected .integrity_constraints - .push(enforce!(call!(foo::foo_constraint(vector!(access!(clk, Type::Felt)))))); + .push(enforce!(call!(foo::foo_constraint(vector!(access!(clk, ty!(felt).unwrap())))))); expected .integrity_constraints - .push(enforce!(call!(bar::bar_constraint(vector!(access!(clk, Type::Felt)))))); - expected - .boundary_constraints - .push(enforce!(eq!(bounded_access!(clk, Boundary::First, Type::Felt), int!(0)))); + .push(enforce!(call!(bar::bar_constraint(vector!(access!(clk, ty!(felt).unwrap())))))); + expected.boundary_constraints.push(enforce!(eq!( + bounded_access!(clk, Boundary::First, ty!(felt).unwrap()), + int!(0) + ))); ParseTest::new() .expect_program_ast_from_file("src/parser/tests/input/import_example.air", expected); diff --git a/parser/src/sema/binding_type.rs b/parser/src/sema/binding_type.rs index 2659634ef..06dcd9d5d 100644 --- a/parser/src/sema/binding_type.rs +++ b/parser/src/sema/binding_type.rs @@ -1,6 +1,10 @@ use std::fmt; -use crate::ast::{AccessType, BusType, FunctionType, InvalidAccessError, TraceBinding, Type}; +use typing::*; + +use crate::ast::{ + Access, AccessType, BusType, FunctionType, InvalidAccessError, TraceBinding, TraceSegment, Type, +}; /// This type provides type and contextual information about a binding, /// i.e. not only does it tell us the type of a binding, but what type @@ -20,6 +24,7 @@ pub enum BindingType { /// /// The result type is None if the function is an evaluator Function(FunctionType), + Evaluator(Vec), /// A binding to a bus definition Bus(BusType), /// A function parameter corresponding to trace columns @@ -33,22 +38,27 @@ pub enum BindingType { /// A direct reference to a periodic column PeriodicColumn(usize), } -impl BindingType { + +impl Typing for BindingType { /// Get the value type of this binding, if applicable - pub fn ty(&self) -> Option { + fn ty(&self) -> Option { match self { - Self::TraceColumn(tb) | Self::TraceParam(tb) => Some(tb.ty()), - Self::Vector(elems) => Some(Type::Vector(elems.len())), + Self::TraceColumn(tb) | Self::TraceParam(tb) => tb.ty(), + Self::Vector(elems) => elems.ty(), Self::Alias(aliased) => aliased.ty(), Self::Local(ty) | Self::Constant(ty) | Self::PublicInput(ty) => Some(*ty), - Self::PeriodicColumn(_) => Some(Type::Felt), + Self::PeriodicColumn(_) => ty!(felt), Self::Function(ty) => ty.result(), - Self::Bus(_) => Some(Type::Felt), + Self::Evaluator(_) => None, + Self::Bus(_) => ty!(felt), } } +} +impl Access for BindingType { + type Accessed = Self; /// Produce a new [BindingType] which represents accessing the current binding via `access_type` - pub fn access(&self, access_type: AccessType) -> Result { + fn access(&self, access_type: AccessType) -> Result { match self { Self::Alias(aliased) => aliased.access(access_type), Self::Local(ty) => ty.access(access_type).map(Self::Local), @@ -81,19 +91,22 @@ impl BindingType { AccessType::Default => Ok(Self::PeriodicColumn(*period)), _ => Err(InvalidAccessError::IndexIntoScalar), }, - Self::Function(_) => Err(InvalidAccessError::InvalidBinding), + Self::Function(_) | Self::Evaluator(_) => Err(InvalidAccessError::InvalidBinding), Self::Bus(bus) => Ok(Self::Bus(*bus)), } } } + impl fmt::Display for BindingType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // TODO: Update to reflect the type signature match self { Self::Alias(aliased) => write!(f, "{aliased}"), Self::Local(_) => f.write_str("local"), Self::Constant(_) => f.write_str("constant"), Self::Vector(_) => f.write_str("vector"), Self::Function(_) => f.write_str("function"), + Self::Evaluator(_) => f.write_str("evaluator"), Self::TraceColumn(_) | Self::TraceParam(_) => f.write_str("trace column(s)"), Self::PublicInput(_) => f.write_str("public input(s)"), Self::PeriodicColumn(_) => f.write_str("periodic column(s)"), diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index b14773d54..8cd734327 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -168,7 +168,7 @@ impl VisitMut for SemanticAnalysis<'_> { name: Some(segment.name), offset: 0, size: segment.size, - ty: Type::Vector(segment.size), + ty: ty!(felt[segment.size]).unwrap(), }) ), None @@ -194,7 +194,7 @@ impl VisitMut for SemanticAnalysis<'_> { assert_eq!( self.locals.insert( NamespacedIdentifier::Binding(input.name()), - BindingType::PublicInput(Type::Vector(input.size())) + BindingType::PublicInput(ty!(felt[input.size()]).unwrap()) ), None ); @@ -215,7 +215,8 @@ impl VisitMut for SemanticAnalysis<'_> { } // It should be impossible for there to be a local by this name at this point assert_eq!( - self.locals.insert(namespaced_name, BindingType::Constant(constant.ty())), + self.locals + .insert(namespaced_name, BindingType::Constant(constant.ty().unwrap())), None ); } @@ -229,10 +230,8 @@ impl VisitMut for SemanticAnalysis<'_> { self.declaration_import_conflict(namespaced_name.span(), prev.span())?; } assert_eq!( - self.locals.insert( - namespaced_name, - BindingType::Function(FunctionType::Evaluator(function.params.clone())) - ), + self.locals + .insert(namespaced_name, BindingType::Function(function.fn_ty.clone())), None ); } @@ -243,13 +242,8 @@ impl VisitMut for SemanticAnalysis<'_> { self.declaration_import_conflict(namespaced_name.span(), prev.span())?; } assert_eq!( - self.locals.insert( - namespaced_name, - BindingType::Function(FunctionType::Function( - function.param_types(), - function.return_type - )) - ), + self.locals + .insert(namespaced_name, BindingType::Function(function.fn_ty.clone())), None ); } @@ -620,7 +614,7 @@ impl VisitMut for SemanticAnalysis<'_> { // If we were unable to determine a type for any of the bindings, use a large vector as a // placeholder - let expected = BindingType::Local(result_ty.unwrap_or(Type::Vector(u32::MAX as usize))); + let expected = BindingType::Local(result_ty.unwrap_or(ty!(_[u32::MAX as usize]).unwrap())); // Bind everything now, resolving any deferred types using our fallback expected type for (binding, _, binding_ty) in binding_tys.drain(..) { @@ -644,8 +638,8 @@ impl VisitMut for SemanticAnalysis<'_> { // Store the result type of this comprehension result_ty = match result_ty { - Some(Type::Vector(_)) => result_ty, - Some(Type::Matrix(rows, _)) => Some(Type::Vector(rows)), + Some(Type::Vector(_, _)) => result_ty, + Some(Type::Matrix(sty, rows, _)) => ty!(sty[rows]), _ => None, }; expr.ty = result_ty; @@ -719,7 +713,7 @@ impl VisitMut for SemanticAnalysis<'_> { // * Must be trace bindings or aliases of same // * Must match the type signature of the callee if let Ok(ty) = callee_binding_ty { - if let BindingType::Function(FunctionType::Evaluator(ref params)) = ty.item { + if let BindingType::Evaluator(params) = ty.item { for (arg, param) in expr.args.iter().zip(params.iter()) { self.validate_evaluator_argument(expr.span(), arg, param)?; } @@ -737,25 +731,40 @@ impl VisitMut for SemanticAnalysis<'_> { self.visit_mut_scalar_expr(expr.rhs.as_mut())?; // Validate the operand types - match (expr.lhs.ty(), expr.rhs.ty()) { - (Ok(Some(lty)), Ok(Some(rty))) => { - if lty != rty { - self.has_type_errors = true; - // Note: We don't break here but at the end of the module's compilation, as we - // want to continue to gather as many errors as possible - let _ = self.type_mismatch( - Some(<y), - expr.lhs.span(), - &rty, - expr.rhs.span(), - expr.span(), - ); - } + match expr.bin_ty.infer_ty() { + Err(err) => { + self.has_type_errors = true; + // Note: We don't break here but at the end of the module's compilation, as we + // want to continue to gather as many errors as possible + self.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid binary expression") + .with_primary_label(expr.span(), format!("{err}")) + .emit(); ControlFlow::Continue(()) }, - _ => ControlFlow::Continue(()), + Ok(_) => ControlFlow::Continue(()), } } + // match (expr.lhs.ty(), expr.rhs.ty()) { + // (Some(lty), Some(rty)) => { + // if lty != rty { + // self.has_type_errors = true; + // // Note: We don't break here but at the end of the module's compilation, as we + // // want to continue to gather as many errors as possible + // let _ = self.type_mismatch( + // Some(<y), + // expr.lhs.span(), + // &rty, + // expr.rhs.span(), + // expr.span(), + // ); + // } + // ControlFlow::Continue(()) + // }, + // _ => ControlFlow::Continue(()), + // } + // } fn visit_mut_range_bound( &mut self, @@ -799,8 +808,8 @@ impl VisitMut for SemanticAnalysis<'_> { .with_primary_label( expr.span(), format!( - "constant is not a valid range bound: expected scalar, got {}", - const_expr.ty() + "constant is not a valid range bound: expected uint, got {}", + const_expr.show_ty() ), ) .emit(); @@ -942,7 +951,9 @@ impl VisitMut for SemanticAnalysis<'_> { // be captured as a vector of size 1 AccessType::Slice(ref range) => { let range = range.to_slice_range(); - assert_eq!(expr.ty.replace(Type::Vector(range.len())), None) + let sty = expr.ty.scalar_ty(); + let new_ty = ty!(sty[range.len()]).unwrap(); + assert_eq!(expr.ty.replace(new_ty), None) }, // All other access types can be derived from the binding type _ => assert_eq!(expr.ty.replace(binding_ty.ty().unwrap()), None), @@ -958,14 +969,14 @@ impl VisitMut for SemanticAnalysis<'_> { .with_secondary_label(derived_from, "references this declaration") .emit(); // Continue with a fabricated type - let ty = match &expr.access_type { + let new_ty = match &expr.access_type { AccessType::Slice(range) => { let range = range.to_slice_range(); - Type::Vector(range.len()) + ty!(felt[range.len()]).unwrap() }, - _ => Type::Felt, + _ => ty!(felt).unwrap(), }; - assert_eq!(expr.ty.replace(ty), None); + assert_eq!(expr.ty.replace(new_ty), None); ControlFlow::Continue(()) }, } @@ -1010,6 +1021,7 @@ impl VisitMut for SemanticAnalysis<'_> { // These binding types are module-local declarations BindingType::Constant(_) | BindingType::Function(_) + | BindingType::Evaluator(_) | BindingType::PeriodicColumn(_) => { *expr = ResolvableIdentifier::Resolved(QualifiedIdentifier::new( current_module, @@ -1227,9 +1239,9 @@ impl SemanticAnalysis<'_> { // Note: We don't break here but at the end of the module's compilation, // as we want to continue to gather as many errors as possible let _ = self.type_mismatch( - Some(&Type::Vector(param.size)), + ty!(_[param.size]).as_ref(), arg.span(), - &Type::Vector(size), + &ty!(_[size]).unwrap(), param.span(), span, ); @@ -1243,7 +1255,7 @@ impl SemanticAnalysis<'_> { param.id, 0, param.size, - Type::Vector(param.size), + ty!(felt[param.size]).unwrap(), )); // Note: We don't break here but at the end of the module's compilation, as // we want to continue to gather as many errors as possible @@ -1367,9 +1379,9 @@ impl SemanticAnalysis<'_> { } else { let inferred = tb.ty(); return self.type_mismatch( - Some(&inferred), + inferred.as_ref(), access.span(), - &Type::Felt, + &ty!(felt).unwrap(), ty.span(), constraint_span, ); @@ -1386,7 +1398,7 @@ impl SemanticAnalysis<'_> { 0, 0, 1, - Type::Felt, + ty!(felt).unwrap(), )); return self.binding_mismatch( &aty, @@ -1483,7 +1495,7 @@ impl SemanticAnalysis<'_> { self.type_mismatch( Some(ty), access.span(), - &Type::Felt, + &ty!(_).unwrap(), found.span(), constraint_span, )?; @@ -1503,7 +1515,7 @@ impl SemanticAnalysis<'_> { self.type_mismatch( access.ty.as_ref(), access.span(), - &Type::Felt, + &ty!(_).unwrap(), access.name.span(), constraint_span, )?; @@ -1776,9 +1788,11 @@ impl SemanticAnalysis<'_> { fn expr_binding_type(&self, expr: &Expr) -> Result { match expr { - Expr::Const(constant) => Ok(BindingType::Local(constant.ty())), + Expr::Const(constant) => { + Ok(BindingType::Local(constant.ty().expect("constant type should be known"))) + }, Expr::Range(range) => { - Ok(BindingType::Local(Type::Vector(range.to_slice_range().len()))) + Ok(BindingType::Local(ty!(uint[range.to_slice_range().len()]).unwrap())) }, Expr::Vector(elems) => { let mut binding_tys = Vec::with_capacity(elems.len()); @@ -1789,14 +1803,12 @@ impl SemanticAnalysis<'_> { Ok(BindingType::Vector(binding_tys)) }, Expr::Matrix(expr) => { - let rows = expr.len(); - let columns = expr[0].len(); - Ok(BindingType::Local(Type::Matrix(rows, columns))) + Ok(BindingType::Local(expr.ty().expect("matrix type should be known"))) }, Expr::SymbolAccess(expr) => self.access_binding_type(expr), Expr::Call(Call { ty: None, .. }) => Err(InvalidAccessError::InvalidBinding), Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)), - Expr::Binary(_) => Ok(BindingType::Local(Type::Felt)), + Expr::Binary(be) => Ok(BindingType::Local(be.ty().or(ty!(felt)).unwrap())), Expr::ListComprehension(lc) => { match lc.ty { Some(ty) => Ok(BindingType::Local(ty)), @@ -1816,8 +1828,9 @@ impl SemanticAnalysis<'_> { .emit(); Err(InvalidAccessError::InvalidBinding) }, - Expr::BusOperation(_expr) => Ok(BindingType::Local(Type::Felt)), - Expr::Null(_) | Expr::Unconstrained(_) => Ok(BindingType::Local(Type::Felt)), + Expr::BusOperation(_) | Expr::Null(_) | Expr::Unconstrained(_) => { + Ok(BindingType::Local(ty!(felt).unwrap())) + }, } } @@ -1882,8 +1895,7 @@ impl SemanticAnalysis<'_> { // it elsewhere. For the time being, functions are not // implemented, so the only place this comes up is with these // list folding builtins - let folder_ty = - FunctionType::Function(vec![Type::Vector(usize::MAX)], Type::Felt); + let folder_ty = FunctionType::Function(vec![ty!(felt[usize::MAX])], ty!(felt)); Ok(Span::new(qid.span(), BindingType::Function(folder_ty))) }, name => unimplemented!("unsupported builtin: {}", name), @@ -1896,14 +1908,17 @@ impl SemanticAnalysis<'_> { imported_from .constants .get(qid.as_ref()) - .map(|c| Span::new(c.span(), BindingType::Constant(c.ty()))) + .map(|c| { + Span::new( + c.span(), + BindingType::Constant(c.ty().expect("constant type should be known")), + ) + }) .or_else(|| { - imported_from.evaluators.get(qid.as_ref()).map(|e| { - Span::new( - e.span(), - BindingType::Function(FunctionType::Evaluator(e.params.clone())), - ) - }) + imported_from + .evaluators + .get(qid.as_ref()) + .map(|e| Span::new(e.span(), BindingType::Evaluator(e.params.clone()))) }) .ok_or(InvalidAccessError::UndefinedVariable) } diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index 216785af1..9bff10304 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -443,14 +443,8 @@ impl VisitMut for ConstantPropagation<'_> { } if is_constant { - let ty = match vector.first().and_then(|e| e.ty()).unwrap() { - Type::Felt => Type::Vector(vector.len()), - Type::Vector(n) => Type::Matrix(vector.len(), n), - _ => unreachable!(), - }; - - let new_expr = match ty { - Type::Vector(_) => ConstantExpr::Vector( + let new_expr = match vector.ty().expect("vector type must be known") { + Type::Vector(_, _) => ConstantExpr::Vector( vector .iter() .map(|expr| match expr { From e7fadd4cbcff36d745525921c7ff6ec9c3c33623 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Fri, 1 Aug 2025 16:45:56 +0200 Subject: [PATCH 15/42] chores: cargo clippy + fmt --- mir/src/ir/mod.rs | 1 + mir/src/ir/nodes/ops/parameter.rs | 2 +- mir/src/ir/nodes/ops/value.rs | 5 +- mir/src/passes/translate.rs | 2 +- mir/src/passes/unrolling.rs | 67 ++++++++++--------- parser/src/ast/types.rs | 4 +- parser/src/sema/semantic_analysis.rs | 6 +- parser/src/transforms/constant_propagation.rs | 2 +- 8 files changed, 44 insertions(+), 45 deletions(-) diff --git a/mir/src/ir/mod.rs b/mir/src/ir/mod.rs index 86b7d8630..64cb5356a 100644 --- a/mir/src/ir/mod.rs +++ b/mir/src/ir/mod.rs @@ -19,6 +19,7 @@ pub use node::Node; pub use nodes::*; pub use owner::Owner; pub use quad_eval::{QuadFelt, RandomInputs}; +#[allow(unused_imports)] pub use typing::*; pub use utils::*; /// A trait for nodes that can have children diff --git a/mir/src/ir/nodes/ops/parameter.rs b/mir/src/ir/nodes/ops/parameter.rs index 4a1bdfe44..76e552fd5 100644 --- a/mir/src/ir/nodes/ops/parameter.rs +++ b/mir/src/ir/nodes/ops/parameter.rs @@ -1,9 +1,9 @@ use std::hash::{Hash, Hasher}; use miden_diagnostics::{SourceSpan, Spanned}; +use typing::*; use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Singleton}; -use typing::*; /// A MIR operation to represent a `Parameter` in a function or evaluator. /// Also used in If and For loops to represent declared parameters. diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index 868f8b621..468cb5dec 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -1,10 +1,7 @@ -use air_parser::ast::{ - self, BusType, Identifier, QualifiedIdentifier, TraceColumnIndex, TraceSegmentId, -}; +use air_parser::ast::{BusType, Identifier, QualifiedIdentifier, TraceColumnIndex, TraceSegmentId}; use miden_diagnostics::{SourceSpan, Spanned}; use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Singleton}; -use typing::*; /// A MIR operation to represent a known value, [Value]. /// diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 282bb0ee1..64ac61832 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -1160,7 +1160,7 @@ impl<'a> MirBuilder<'a> { // // In that case, replacing the default type (Felt) with the one from the access if let Some(mut param) = let_bound_access_expr.as_parameter_mut() { - if let Some(_) = &access.ty { + if access.ty.is_some() { param.ty = access.ty } } diff --git a/mir/src/passes/unrolling.rs b/mir/src/passes/unrolling.rs index 283d6092f..517787f55 100644 --- a/mir/src/passes/unrolling.rs +++ b/mir/src/passes/unrolling.rs @@ -1083,47 +1083,48 @@ impl Visitor for UnrollingSecondPass<'_> { let new_node = self.nodes_to_replace.get(&body.get_ptr()).unwrap().1.clone(); // If there is a selector, we need to enforce it on the body - let new_node_with_selector_if_needed = - if let Some(selector) = self.for_inlining_context.clone().unwrap().selector { - if let Op::Vector(new_node_vector) = new_node.borrow().deref() { - let new_node_vec = new_node_vector.children().borrow().deref().clone(); - let mut new_vec = vec![]; - for new_node_child in new_node_vec.into_iter() { - let zero_node = Value::create(SpannedMirValue { - span: Default::default(), - value: MirValue::Constant(ConstantValue::Felt(0)), - }); - // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> - // Enf(Sub(lhs, rhs) == 0), but it introduces an - // unnecessary zero node - let new_node_child_with_selector = Sub::create( - Mul::create( - duplicate_node(selector.clone(), &mut HashMap::new()), - new_node_child, - root.span(), - ), - zero_node, - root.span(), - ); - new_vec.push(new_node_child_with_selector); - } - Vector::create(new_vec, root.span()) - } else { + let new_node_with_selector_if_needed = if let Some(selector) = + self.for_inlining_context.clone().unwrap().selector + { + if let Op::Vector(new_node_vector) = new_node.borrow().deref() { + let new_node_vec = new_node_vector.children().borrow().deref().clone(); + let mut new_vec = vec![]; + for new_node_child in new_node_vec.into_iter() { let zero_node = Value::create(SpannedMirValue { span: Default::default(), value: MirValue::Constant(ConstantValue::Felt(0)), }); - // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> Enf(Sub(lhs, - // rhs) == 0), but it introduces an unnecessary zero node - Sub::create( - Mul::create(selector, new_node, root.span()), + // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> + // Enf(Sub(lhs, rhs) == 0), but it introduces an + // unnecessary zero node + let new_node_child_with_selector = Sub::create( + Mul::create( + duplicate_node(selector.clone(), &mut HashMap::new()), + new_node_child, + root.span(), + ), zero_node, root.span(), - ) + ); + new_vec.push(new_node_child_with_selector); } + Vector::create(new_vec, root.span()) } else { - new_node - }; + let zero_node = Value::create(SpannedMirValue { + span: Default::default(), + value: MirValue::Constant(ConstantValue::Felt(0)), + }); + // FIXME: The Sub here is used to keep the form of Eq(lhs, rhs) -> Enf(Sub(lhs, + // rhs) == 0), but it introduces an unnecessary zero node + Sub::create( + Mul::create(selector, new_node, root.span()), + zero_node, + root.span(), + ) + } + } else { + new_node + }; root.as_op().unwrap().set(&new_node_with_selector_if_needed); diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index 691ee8c0a..f86653793 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs @@ -1,6 +1,6 @@ +pub use typing::{bty, fty, kind, sty, tty, ty, tys, *}; + use super::*; -pub use typing::*; -pub use typing::{bty, fty, kind, sty, tty, ty, tys}; impl Access for Type { type Accessed = Self; diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 8cd734327..e882b2267 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -638,7 +638,7 @@ impl VisitMut for SemanticAnalysis<'_> { // Store the result type of this comprehension result_ty = match result_ty { - Some(Type::Vector(_, _)) => result_ty, + Some(Type::Vector(..)) => result_ty, Some(Type::Matrix(sty, rows, _)) => ty!(sty[rows]), _ => None, }; @@ -750,8 +750,8 @@ impl VisitMut for SemanticAnalysis<'_> { // (Some(lty), Some(rty)) => { // if lty != rty { // self.has_type_errors = true; - // // Note: We don't break here but at the end of the module's compilation, as we - // // want to continue to gather as many errors as possible + // // Note: We don't break here but at the end of the module's compilation, as + // we // want to continue to gather as many errors as possible // let _ = self.type_mismatch( // Some(<y), // expr.lhs.span(), diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index 9bff10304..e6db48023 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -444,7 +444,7 @@ impl VisitMut for ConstantPropagation<'_> { if is_constant { let new_expr = match vector.ty().expect("vector type must be known") { - Type::Vector(_, _) => ConstantExpr::Vector( + Type::Vector(..) => ConstantExpr::Vector( vector .iter() .map(|expr| match expr { From ee8b4a7e6ed41f68370159e0e7c4192d21393b17 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 5 Aug 2025 10:15:56 +0200 Subject: [PATCH 16/42] fix(typing): fix access::Default on TraceBinding --- parser/src/ast/trace.rs | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/parser/src/ast/trace.rs b/parser/src/ast/trace.rs index 333b5dcd6..b873417a9 100644 --- a/parser/src/ast/trace.rs +++ b/parser/src/ast/trace.rs @@ -267,7 +267,11 @@ impl Access for TraceBinding { /// Derive a new [TraceBinding] derived from the current one given an [AccessType] fn access(&self, access_type: AccessType) -> Result { match access_type { - AccessType::Default => Ok(*self), + // + AccessType::Default => match self.size { + 1 => Ok(Self { ty: ty!(felt).unwrap(), ..*self }), + _ => Ok(*self), + }, AccessType::Slice(_) if self.is_scalar() => Err(InvalidAccessError::SliceOfScalar), AccessType::Slice(range) => { let slice_range = range.to_slice_range(); From 6fee06e07cd1f3ce330baae1e1fab7cc26da9a6f Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 5 Aug 2025 10:47:37 +0200 Subject: [PATCH 17/42] fix(typing): fix TraceSegment::kind() --- parser/src/ast/declarations.rs | 12 ++++++++++-- parser/src/ast/trace.rs | 23 +++++++---------------- 2 files changed, 17 insertions(+), 18 deletions(-) diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index e0d824698..fe6df2f63 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -393,8 +393,8 @@ impl EvaluatorFunction { params: Vec, body: Vec, ) -> Self { - let p = params.iter().map(|ty| ty.ty()).collect::>(); - let fn_ty = FunctionType::Evaluator(p); + let param_tys = params.iter().map(|ty| ty.ty()).collect::>(); + let fn_ty = FunctionType::Evaluator(param_tys); Self { span, name, params, body, fn_ty } } } @@ -404,6 +404,14 @@ impl PartialEq for EvaluatorFunction { self.name == other.name && self.params == other.params && self.body == other.body } } +impl Typing for EvaluatorFunction { + fn ty(&self) -> Option { + None + } + fn kind(&self) -> Option { + Some(Kind::Callable(self.fn_ty.clone())) + } +} /// Functions take a group of expressions as parameters and returns a value. /// diff --git a/parser/src/ast/trace.rs b/parser/src/ast/trace.rs index b873417a9..5a70ef2bd 100644 --- a/parser/src/ast/trace.rs +++ b/parser/src/ast/trace.rs @@ -1,7 +1,7 @@ use std::fmt; use miden_diagnostics::{SourceSpan, Spanned}; -use typing::{FunctionType, Kind, Typing, tty, ty}; +use typing::{Kind, Typing, tty, ty}; use super::*; @@ -28,7 +28,6 @@ pub struct TraceSegment { /// A vector of `size` elements which tracks for every column whether a /// constraint has been applied to that column, and on what boundaries. pub boundary_constrained: Vec>, - pub fn_ty: Option, } impl TraceSegment { /// Constructs a new [TraceSegment] given a span, segment id, name, and a vector of (Identifier, @@ -58,7 +57,7 @@ impl TraceSegment { // The size of the segment is the sum of the sizes of all the bindings let size = offset; - let mut res = Self { + Self { span, id, name, @@ -68,13 +67,7 @@ impl TraceSegment { Span::new(SourceSpan::UNKNOWN, ColumnBoundaryFlags::EMPTY); size ], - fn_ty: None, - }; - res.fn_ty = match res.kind() { - Some(Kind::Callable(fty)) => Some(fty), - _ => None, - }; - res + } } /// Returns true if `column` is constrained on `boundary` @@ -110,12 +103,10 @@ impl TraceSegment { } impl Typing for TraceSegment { fn ty(&self) -> Option { - None - } - fn kind(&self) -> Option { - Some(Kind::Callable(FunctionType::Evaluator( - self.bindings.iter().map(|b| b.ty()).collect(), - ))) + match self.size { + 1 => self.bindings.first().map(|b| b.ty())?, + _ => ty!(felt[self.size]), + } } } impl fmt::Debug for TraceSegment { From 68b28ab4d99cabf82b086ce01ef1569935415d85 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 5 Aug 2025 12:36:13 +0200 Subject: [PATCH 18/42] feat(typing): update subtyping rules: make `?` and `_` top types --- typing/src/lib.rs | 171 +++++++++++++++++++++++----------------------- 1 file changed, 87 insertions(+), 84 deletions(-) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index 635a667f4..a4e99b544 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -126,8 +126,8 @@ pub trait Typing { self.is_vector() } /// Returns true if the shape of `self` is a sub-shape of the shape of `other` - /// The shapes are compatible if: - /// - self is `?` (None) + /// The shape of `self` is a sub-shape of the shape of `other` if: + /// - other is `?` (None) /// - both are scalars /// - both are vectors of the same length /// - both are vectors with one of the lengths being `u32::MAX` @@ -135,15 +135,16 @@ pub trait Typing { /// - both are matrices with one or more of the rows or columns being `u32::MAX`, the other pair /// (if any) being equal /// - /// self\\other || _[r,c] | _[l] | _ | ? - /// ============||========|======|===|== - /// _[r,c] || y | n | n | n - /// _[l] || n | y | n | n - /// _ || n | n | y | n - /// ? || y | y | y | y + /// self\\other || ? | _ | _[l] | _[r,c] + /// ============||===|===|======|======== + /// ? || y | n | n | n + /// _ || y | y | n | n + /// _[l] || y | n | y | n + /// _[r,c] || y | n | n | y + /// fn is_subshape(&self, other: &impl Typing) -> bool { match (self.ty(), other.ty()) { - (None, _) => true, + (_, None) => true, (Some(Type::Scalar(_)), Some(Type::Scalar(_))) => true, (Some(Type::Vector(_, len1)), Some(Type::Vector(_, len2))) => { len1 == len2 || len1 == u32::MAX as usize || len2 == u32::MAX as usize @@ -166,18 +167,18 @@ pub trait Typing { /// - both are matrices with one or more of the rows or columns being `u32::MAX`, the other pair /// (if any) being equal /// - /// self\\other || _[r,c] | _[l] | _ | ? - /// ============||========|======|===|== - /// _[r,c] || y | n | n | y - /// _[l] || n | y | n | y - /// _ || n | n | y | y - /// ? || y | y | y | y + /// self\\other || ? | _ | _[l] | _[r,c] + /// ============||===|===|======|======== + /// ? || y | y | y | y + /// _ || y | y | n | n + /// _[l] || y | n | y | n + /// _[r,c] || y | n | n | y /// /// This is a more relaxed version of [Typing::is_subshape], /// allowing for bi-directional compatibility checks. The only - /// difference is that it allows for `other` to be `?` (None). + /// difference is that it allows for `self` to be `?` (None). fn is_shape_compatible(&self, other: &impl Typing) -> bool { - other.ty().is_none() || self.is_subshape(other) + self.ty().is_none() || self.is_subshape(other) } /// Returns true if `self` is a subtype of `other` /// Notation: @@ -191,26 +192,27 @@ pub trait Typing { /// Integer type /// /// Subtyping rules: - /// - felt > bool > _ - /// - felt > uint > _ + /// - _ > felt > bool + /// - _ > felt > uint /// /// Which means: - /// - `_` is a subtype of all scalar types + /// - all scalar types are subtypes of `_` /// - `bool` is a subtype of `felt`: a `bool` is a `felt with a `is_bool` property /// - `uint` is a subtype of `felt`: a `uint` is a `felt` with the `constant` property /// - /// self\\other || felt | bool | uint | _ | - /// ============||======|======|======|===| - /// felt || y | n | n | n | - /// bool || y | y | n | n | - /// uint || y | n | y | n | - /// _ || y | y | y | y | + /// self\\other || _ | felt | bool | uint | + /// ============||===|======|======|======| + /// _ || y | n | n | n | + /// felt || y | y | n | n | + /// bool || y | y | y | n | + /// uint || y | y | n | y | fn is_scalar_subtype(&self, other: &impl Typing) -> bool { !matches!( (self.scalar_ty(), other.scalar_ty()), - (sty!(felt), sty!(bool) | sty!(uint) | sty!(_)) - | (sty!(bool), sty!(uint) | sty!(_)) - | (sty!(uint), sty!(bool) | sty!(_)) + (sty!(_), sty!(felt) | sty!(bool) | sty!(uint)) + | (sty!(felt), sty!(bool) | sty!(uint)) + | (sty!(bool), sty!(uint)) + | (sty!(uint), sty!(bool)) ) } /// Returns true if `self` is a subtype of `other` @@ -232,25 +234,26 @@ pub trait Typing { /// /// Subtyping rules: /// ? > _ > felt > bool - /// ... > felt > uint + /// ? > _ > felt > uint /// ? > _[l] > felt[l] > bool[l] - /// ... > felt[l] > uint[l] + /// ? > _[l] > felt[l] > uint[l] /// ? > _[r, c] > felt[r, c] > bool[r, c] - /// ... > felt[r, c] > uint[r, c] - /// Assuming shapes are compatible, this function checks if the scalar types, + /// ? > _[r, c] > felt[r, c] > uint[r, c] + /// Assuming the shape of `self` is a sub-shape of the shape of `other`, + /// this function checks if `self` is a subtype of `other`, /// with the added case of `?`, which all types are subtypes of. /// See [Typing::is_scalar_subtype] for a more detailed explanation /// of the subtyping rules of scalar types. /// - /// self\\other || felt | bool | uint | _ | ? | - /// ============||======|======|======|===|===| - /// felt ||[ y | n | n | n]| n | - /// bool ||[ y | y | n | n]| n | - /// uint ||[ y | n | y | n]| n | - /// _ ||[ y | y | y | y]| n | - /// ? || y | y | y | y | y | + /// self\\other || ? | _ | felt | bool | uint | + /// ============||===|===|======|======|======| + /// ? || y | n | n | n | n | + /// _ || y |[y | n | n | n]| + /// felt || y |[y | y | n | n]| + /// bool || y |[y | y | y | n]| + /// uint || y |[y | y | n | y]| /// - /// = self.is_scalar_subtype(other) | self == ? + /// = self.is_scalar_subtype(other) | other == ? /// [...] Denotes the result of the [Typing::is_scalar_subtype] method. fn is_subtype(&self, other: &impl Typing) -> bool { self.is_subshape(other) && self.is_scalar_subtype(other) @@ -675,24 +678,24 @@ mod tests { #[test] fn test_typing_subtype() { assert_subtype!(ty!(?); ty!(?)); - assert_subtype!(ty!(?); ty!(_)); - assert_subtype!(ty!(?); ty!(felt)); - assert_subtype!(ty!(?); ty!(bool)); - assert_subtype!(ty!(?); ty!(uint)); - assert_subtype!(ty!(?); ty!(_[5])); - assert_subtype!(ty!(?); ty!(felt[5])); - assert_subtype!(ty!(?); ty!(bool[5])); - assert_subtype!(ty!(?); ty!(uint[5])); - assert_subtype!(ty!(?); ty!(_[3, 4])); - assert_subtype!(ty!(?); ty!(felt[3, 4])); - assert_subtype!(ty!(?); ty!(bool[3, 4])); - assert_subtype!(ty!(?); ty!(uint[3, 4])); - - assert_subtype!(ty!(_); !ty!(?)); + assert_subtype!(ty!(?); !ty!(_)); + assert_subtype!(ty!(?); !ty!(felt)); + assert_subtype!(ty!(?); !ty!(bool)); + assert_subtype!(ty!(?); !ty!(uint)); + assert_subtype!(ty!(?); !ty!(_[5])); + assert_subtype!(ty!(?); !ty!(felt[5])); + assert_subtype!(ty!(?); !ty!(bool[5])); + assert_subtype!(ty!(?); !ty!(uint[5])); + assert_subtype!(ty!(?); !ty!(_[3, 4])); + assert_subtype!(ty!(?); !ty!(felt[3, 4])); + assert_subtype!(ty!(?); !ty!(bool[3, 4])); + assert_subtype!(ty!(?); !ty!(uint[3, 4])); + + assert_subtype!(ty!(_); ty!(?)); assert_subtype!(ty!(_); ty!(_)); - assert_subtype!(ty!(_); ty!(felt)); - assert_subtype!(ty!(_); ty!(bool)); - assert_subtype!(ty!(_); ty!(uint)); + assert_subtype!(ty!(_); !ty!(felt)); + assert_subtype!(ty!(_); !ty!(bool)); + assert_subtype!(ty!(_); !ty!(uint)); assert_subtype!(ty!(_); !ty!(_[5])); assert_subtype!(ty!(_); !ty!(felt[5])); assert_subtype!(ty!(_); !ty!(bool[5])); @@ -702,8 +705,8 @@ mod tests { assert_subtype!(ty!(_); !ty!(bool[3, 4])); assert_subtype!(ty!(_); !ty!(uint[3, 4])); - assert_subtype!(ty!(felt); !ty!(?)); - assert_subtype!(ty!(felt); !ty!(_)); + assert_subtype!(ty!(felt); ty!(?)); + assert_subtype!(ty!(felt); ty!(_)); assert_subtype!(ty!(felt); ty!(felt)); assert_subtype!(ty!(felt); !ty!(bool)); assert_subtype!(ty!(felt); !ty!(uint)); @@ -716,8 +719,8 @@ mod tests { assert_subtype!(ty!(felt); !ty!(bool[3, 4])); assert_subtype!(ty!(felt); !ty!(uint[3, 4])); - assert_subtype!(ty!(bool); !ty!(?)); - assert_subtype!(ty!(bool); !ty!(_)); + assert_subtype!(ty!(bool); ty!(?)); + assert_subtype!(ty!(bool); ty!(_)); assert_subtype!(ty!(bool); ty!(felt)); assert_subtype!(ty!(bool); ty!(bool)); assert_subtype!(ty!(bool); !ty!(uint)); @@ -730,8 +733,8 @@ mod tests { assert_subtype!(ty!(bool); !ty!(bool[3, 4])); assert_subtype!(ty!(bool); !ty!(uint[3, 4])); - assert_subtype!(ty!(uint); !ty!(?)); - assert_subtype!(ty!(uint); !ty!(_)); + assert_subtype!(ty!(uint); ty!(?)); + assert_subtype!(ty!(uint); ty!(_)); assert_subtype!(ty!(uint); ty!(felt)); assert_subtype!(ty!(uint); !ty!(bool)); assert_subtype!(ty!(uint); ty!(uint)); @@ -744,26 +747,26 @@ mod tests { assert_subtype!(ty!(uint); !ty!(bool[3, 4])); assert_subtype!(ty!(uint); !ty!(uint[3, 4])); - assert_subtype!(ty!(_[5]); !ty!(?)); + assert_subtype!(ty!(_[5]); ty!(?)); assert_subtype!(ty!(_[5]); !ty!(_)); assert_subtype!(ty!(_[5]); !ty!(felt)); assert_subtype!(ty!(_[5]); !ty!(bool)); assert_subtype!(ty!(_[5]); !ty!(uint)); assert_subtype!(ty!(_[5]); ty!(_[5])); - assert_subtype!(ty!(_[5]); ty!(felt[5])); - assert_subtype!(ty!(_[5]); ty!(bool[5])); - assert_subtype!(ty!(_[5]); ty!(uint[5])); + assert_subtype!(ty!(_[5]); !ty!(felt[5])); + assert_subtype!(ty!(_[5]); !ty!(bool[5])); + assert_subtype!(ty!(_[5]); !ty!(uint[5])); assert_subtype!(ty!(_[5]); !ty!(_[3, 4])); assert_subtype!(ty!(_[5]); !ty!(felt[3, 4])); assert_subtype!(ty!(_[5]); !ty!(bool[3, 4])); assert_subtype!(ty!(_[5]); !ty!(uint[3, 4])); - assert_subtype!(ty!(felt[5]); !ty!(?)); + assert_subtype!(ty!(felt[5]); ty!(?)); assert_subtype!(ty!(felt[5]); !ty!(_)); assert_subtype!(ty!(felt[5]); !ty!(felt)); assert_subtype!(ty!(felt[5]); !ty!(bool)); assert_subtype!(ty!(felt[5]); !ty!(uint)); - assert_subtype!(ty!(felt[5]); !ty!(_[5])); + assert_subtype!(ty!(felt[5]); ty!(_[5])); assert_subtype!(ty!(felt[5]); ty!(felt[5])); assert_subtype!(ty!(felt[5]); !ty!(bool[5])); assert_subtype!(ty!(felt[5]); !ty!(uint[5])); @@ -772,12 +775,12 @@ mod tests { assert_subtype!(ty!(felt[5]); !ty!(bool[3, 4])); assert_subtype!(ty!(felt[5]); !ty!(uint[3, 4])); - assert_subtype!(ty!(bool[5]); !ty!(?)); + assert_subtype!(ty!(bool[5]); ty!(?)); assert_subtype!(ty!(bool[5]); !ty!(_)); assert_subtype!(ty!(bool[5]); !ty!(felt)); assert_subtype!(ty!(bool[5]); !ty!(bool)); assert_subtype!(ty!(bool[5]); !ty!(uint)); - assert_subtype!(ty!(bool[5]); !ty!(_[5])); + assert_subtype!(ty!(bool[5]); ty!(_[5])); assert_subtype!(ty!(bool[5]); ty!(felt[5])); assert_subtype!(ty!(bool[5]); ty!(bool[5])); assert_subtype!(ty!(bool[5]); !ty!(uint[5])); @@ -786,12 +789,12 @@ mod tests { assert_subtype!(ty!(bool[5]); !ty!(bool[3, 4])); assert_subtype!(ty!(bool[5]); !ty!(uint[3, 4])); - assert_subtype!(ty!(uint[5]); !ty!(?)); + assert_subtype!(ty!(uint[5]); ty!(?)); assert_subtype!(ty!(uint[5]); !ty!(_)); assert_subtype!(ty!(uint[5]); !ty!(felt)); assert_subtype!(ty!(uint[5]); !ty!(bool)); assert_subtype!(ty!(uint[5]); !ty!(uint)); - assert_subtype!(ty!(uint[5]); !ty!(_[5])); + assert_subtype!(ty!(uint[5]); ty!(_[5])); assert_subtype!(ty!(uint[5]); ty!(felt[5])); assert_subtype!(ty!(uint[5]); !ty!(bool[5])); assert_subtype!(ty!(uint[5]); ty!(uint[5])); @@ -800,7 +803,7 @@ mod tests { assert_subtype!(ty!(uint[5]); !ty!(bool[3, 4])); assert_subtype!(ty!(uint[5]); !ty!(uint[3, 4])); - assert_subtype!(ty!(_[3, 4]); !ty!(?)); + assert_subtype!(ty!(_[3, 4]); ty!(?)); assert_subtype!(ty!(_[3, 4]); !ty!(_)); assert_subtype!(ty!(_[3, 4]); !ty!(felt)); assert_subtype!(ty!(_[3, 4]); !ty!(bool)); @@ -810,11 +813,11 @@ mod tests { assert_subtype!(ty!(_[3, 4]); !ty!(bool[5])); assert_subtype!(ty!(_[3, 4]); !ty!(uint[5])); assert_subtype!(ty!(_[3, 4]); ty!(_[3, 4])); - assert_subtype!(ty!(_[3, 4]); ty!(felt[3, 4])); - assert_subtype!(ty!(_[3, 4]); ty!(bool[3, 4])); - assert_subtype!(ty!(_[3, 4]); ty!(uint[3, 4])); + assert_subtype!(ty!(_[3, 4]); !ty!(felt[3, 4])); + assert_subtype!(ty!(_[3, 4]); !ty!(bool[3, 4])); + assert_subtype!(ty!(_[3, 4]); !ty!(uint[3, 4])); - assert_subtype!(ty!(felt[3, 4]); !ty!(?)); + assert_subtype!(ty!(felt[3, 4]); ty!(?)); assert_subtype!(ty!(felt[3, 4]); !ty!(_)); assert_subtype!(ty!(felt[3, 4]); !ty!(felt)); assert_subtype!(ty!(felt[3, 4]); !ty!(bool)); @@ -823,12 +826,12 @@ mod tests { assert_subtype!(ty!(felt[3, 4]); !ty!(felt[5])); assert_subtype!(ty!(felt[3, 4]); !ty!(bool[5])); assert_subtype!(ty!(felt[3, 4]); !ty!(uint[5])); - assert_subtype!(ty!(felt[3, 4]); !ty!(_[3, 4])); + assert_subtype!(ty!(felt[3, 4]); ty!(_[3, 4])); assert_subtype!(ty!(felt[3, 4]); ty!(felt[3, 4])); assert_subtype!(ty!(felt[3, 4]); !ty!(bool[3, 4])); assert_subtype!(ty!(felt[3, 4]); !ty!(uint[3, 4])); - assert_subtype!(ty!(bool[3, 4]); !ty!(?)); + assert_subtype!(ty!(bool[3, 4]); ty!(?)); assert_subtype!(ty!(bool[3, 4]); !ty!(_)); assert_subtype!(ty!(bool[3, 4]); !ty!(felt)); assert_subtype!(ty!(bool[3, 4]); !ty!(bool)); @@ -837,12 +840,12 @@ mod tests { assert_subtype!(ty!(bool[3, 4]); !ty!(felt[5])); assert_subtype!(ty!(bool[3, 4]); !ty!(bool[5])); assert_subtype!(ty!(bool[3, 4]); !ty!(uint[5])); - assert_subtype!(ty!(bool[3, 4]); !ty!(_[3, 4])); + assert_subtype!(ty!(bool[3, 4]); ty!(_[3, 4])); assert_subtype!(ty!(bool[3, 4]); ty!(felt[3, 4])); assert_subtype!(ty!(bool[3, 4]); ty!(bool[3, 4])); assert_subtype!(ty!(bool[3, 4]); !ty!(uint[3, 4])); - assert_subtype!(ty!(uint[3, 4]); !ty!(?)); + assert_subtype!(ty!(uint[3, 4]); ty!(?)); assert_subtype!(ty!(uint[3, 4]); !ty!(_)); assert_subtype!(ty!(uint[3, 4]); !ty!(felt)); assert_subtype!(ty!(uint[3, 4]); !ty!(bool)); @@ -851,7 +854,7 @@ mod tests { assert_subtype!(ty!(uint[3, 4]); !ty!(felt[5])); assert_subtype!(ty!(uint[3, 4]); !ty!(bool[5])); assert_subtype!(ty!(uint[3, 4]); !ty!(uint[5])); - assert_subtype!(ty!(uint[3, 4]); !ty!(_[3, 4])); + assert_subtype!(ty!(uint[3, 4]); ty!(_[3, 4])); assert_subtype!(ty!(uint[3, 4]); ty!(felt[3, 4])); assert_subtype!(ty!(uint[3, 4]); !ty!(bool[3, 4])); assert_subtype!(ty!(uint[3, 4]); ty!(uint[3, 4])); From 80e56093ebcc31f2f8b1b82b3663ebea871daa10 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 5 Aug 2025 17:00:55 +0200 Subject: [PATCH 19/42] fix(typing): fix trace_type macro when len == 1 --- typing/src/types.rs | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/typing/src/types.rs b/typing/src/types.rs index 0da36a1a9..a7ea79a87 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -164,10 +164,13 @@ macro_rules! tty { ]) }; ($name:ident[$len:expr]) => { - $crate::ty!(felt[$len]) + match $len { + 1 => $crate::ty!(felt), + _ => $crate::ty!(felt[$len]), + } }; ($name:ident) => { - $crate::ty!(felt[1]) + $crate::ty!(felt) }; } @@ -738,28 +741,22 @@ mod tests { #[test] fn test_macro_trace_segment_type() { - assert_eq!(tty!(a), ty!(felt[1])); + assert_eq!(tty!(a), ty!(felt)); assert_eq!(tty!(a[5]), ty!(felt[5])); assert_eq!(tty!([]), Vec::>::new()); - assert_eq!(tty!([a]), vec![ty!(felt[1])]); + assert_eq!(tty!([a]), vec![ty!(felt)]); assert_eq!(tty!([a[5]]), vec![ty!(felt[5])]); - assert_eq!(tty!([a[1], b[3]]), vec![ty!(felt[1]), ty!(felt[3])]); + assert_eq!(tty!([a[1], b[3]]), vec![ty!(felt), ty!(felt[3])]); } #[test] fn test_macro_function_type() { assert_eq!(fty!(ev([])), FunctionType::Evaluator(vec![])); - assert_eq!(fty!(ev([a])), FunctionType::Evaluator(vec![ty!(felt[1])])); + assert_eq!(fty!(ev([a])), FunctionType::Evaluator(vec![ty!(felt)])); assert_eq!(fty!(ev([a[5]])), FunctionType::Evaluator(vec![ty!(felt[5])])); - assert_eq!(fty!(ev([a, b[3]])), FunctionType::Evaluator(vec![ty!(felt[1]), ty!(felt[3])])); - assert_eq!( - fty!(ev([a[1], b[3]])), - FunctionType::Evaluator(vec![ty!(felt[1]), ty!(felt[3])]) - ); - assert_eq!( - fty!(ev([a[1], b[3]])), - FunctionType::Evaluator(vec![ty!(felt[1]), ty!(felt[3])]) - ); + assert_eq!(fty!(ev([a, b[3]])), FunctionType::Evaluator(vec![ty!(felt), ty!(felt[3])])); + assert_eq!(fty!(ev([a[1], b[3]])), FunctionType::Evaluator(vec![ty!(felt), ty!(felt[3])])); + assert_eq!(fty!(ev([a[1], b[3]])), FunctionType::Evaluator(vec![ty!(felt), ty!(felt[3])])); assert_eq!(fty!(fn(uint) -> felt), FunctionType::Function(vec![ty!(uint)], ty!(felt))); assert_eq!( From 99d9403ea344a74e58baad0140ab3ca6e0f171f7 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 5 Aug 2025 17:07:50 +0200 Subject: [PATCH 20/42] Revert "fix(typing): fix access::Default on TraceBinding" This reverts commit ee8b4a7e6ed41f68370159e0e7c4192d21393b17. --- parser/src/ast/trace.rs | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/parser/src/ast/trace.rs b/parser/src/ast/trace.rs index 5a70ef2bd..c5560a328 100644 --- a/parser/src/ast/trace.rs +++ b/parser/src/ast/trace.rs @@ -258,11 +258,7 @@ impl Access for TraceBinding { /// Derive a new [TraceBinding] derived from the current one given an [AccessType] fn access(&self, access_type: AccessType) -> Result { match access_type { - // - AccessType::Default => match self.size { - 1 => Ok(Self { ty: ty!(felt).unwrap(), ..*self }), - _ => Ok(*self), - }, + AccessType::Default => Ok(*self), AccessType::Slice(_) if self.is_scalar() => Err(InvalidAccessError::SliceOfScalar), AccessType::Slice(range) => { let slice_range = range.to_slice_range(); From 72c2eba4d84c9738613415fee3ff2fbd9edf6fda Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 5 Aug 2025 17:12:10 +0200 Subject: [PATCH 21/42] chores(typing): make lint --- typing/src/lib.rs | 1 - 1 file changed, 1 deletion(-) diff --git a/typing/src/lib.rs b/typing/src/lib.rs index a4e99b544..1083d7b9a 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -141,7 +141,6 @@ pub trait Typing { /// _ || y | y | n | n /// _[l] || y | n | y | n /// _[r,c] || y | n | n | y - /// fn is_subshape(&self, other: &impl Typing) -> bool { match (self.ty(), other.ty()) { (_, None) => true, From 4f3e939e8c792a5b291db9e566968e27e8e3dec4 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 6 Aug 2025 15:42:09 +0200 Subject: [PATCH 22/42] fix(mir): fix type-mismatch check on ListComprehension in translate allow widening of result_type to lowest common supertype of iterables --- parser/src/sema/semantic_analysis.rs | 37 ++++++++++++++++++---------- 1 file changed, 24 insertions(+), 13 deletions(-) diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index e882b2267..24ad4c5bb 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -563,19 +563,30 @@ impl VisitMut for SemanticAnalysis<'_> { let iterable = &expr.iterables[i]; let iterable_ty = iterable.ty().unwrap(); - if let Some(expected_ty) = result_ty.replace(iterable_ty) { - if expected_ty != iterable_ty { - self.has_type_errors = true; - // Note: We don't break here but at the end of the module's compilation, as we - // want to continue to gather as many errors as possible - let _ = self.type_mismatch( - Some(&iterable_ty), - iterable.span(), - &expected_ty, - expr.iterables[0].span(), - expr.span(), - ); - } + let lowest_common_supertype = if result_ty.is_some() { + result_ty.lowest_common_supertype(&iterable_ty) + } else { + // If the result type is None, then we use the iterable type as the default + // This means that either: + // - we encountered an error previously, + // - or this is the first iterable we are processing + Some(iterable_ty) + }; + if lowest_common_supertype.is_none() { + // If the lowest common supertype is None, and the result type is Some, + // then the types are incompatible + self.has_type_errors = true; + // Note: We don't break here but at the end of the module's compilation, as we + // want to continue to gather as many errors as possible + let _ = self.type_mismatch( + result_ty.as_ref(), + iterable.span(), + &iterable_ty, + expr.iterables[0].span(), + expr.span(), + ); + } else { + result_ty = lowest_common_supertype; } match self.expr_binding_type(iterable) { Ok(iterable_binding_ty) => { From c60f25fe8a0844e8042553dcf9439302467d895a Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 6 Aug 2025 15:47:45 +0200 Subject: [PATCH 23/42] feat(mir): better error reporting for translate on bin exprs + ? case --- parser/src/sema/semantic_analysis.rs | 27 +++++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 24ad4c5bb..fbcae848e 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -743,6 +743,25 @@ impl VisitMut for SemanticAnalysis<'_> { // Validate the operand types match expr.bin_ty.infer_ty() { + Ok(None) => { + self.has_type_errors = true; + // Note: We don't break here but at the end of the module's compilation, as we + // want to continue to gather as many errors as possible + self.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid binary expression") + .with_primary_label(expr.span(), "unable to infer type for binary expression") + .with_secondary_label( + expr.lhs.span(), + format!("this expression has type: {}", expr.lhs.show_ty()), + ) + .with_secondary_label( + expr.rhs.span(), + format!("this expression has type: {}", expr.rhs.show_ty()), + ) + .emit(); + ControlFlow::Continue(()) + }, Err(err) => { self.has_type_errors = true; // Note: We don't break here but at the end of the module's compilation, as we @@ -751,6 +770,14 @@ impl VisitMut for SemanticAnalysis<'_> { .diagnostic(Severity::Error) .with_message("invalid binary expression") .with_primary_label(expr.span(), format!("{err}")) + .with_secondary_label( + expr.lhs.span(), + format!("this expression has type: {}", expr.lhs.show_ty()), + ) + .with_secondary_label( + expr.rhs.span(), + format!("this expression has type: {}", expr.rhs.show_ty()), + ) .emit(); ControlFlow::Continue(()) }, From 636c77af0cbd05ed05238c2bcdb253021a6c3bb9 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 6 Aug 2025 16:59:16 +0200 Subject: [PATCH 24/42] fix(parser): fix BindingType.ty() + .kind() + sema's call handling --- parser/src/sema/binding_type.rs | 19 ++++++++++++++++++- parser/src/sema/semantic_analysis.rs | 2 +- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/parser/src/sema/binding_type.rs b/parser/src/sema/binding_type.rs index 06dcd9d5d..016f1c2b0 100644 --- a/parser/src/sema/binding_type.rs +++ b/parser/src/sema/binding_type.rs @@ -40,6 +40,23 @@ pub enum BindingType { } impl Typing for BindingType { + fn kind(&self) -> Option { + match self { + Self::Alias(aliased) => aliased.kind(), + Self::Local(ty) => ty.kind(), + Self::Constant(ty) => ty.kind(), + Self::Function(func) => func.kind(), + Self::Evaluator(ev) => { + Some(Kind::Callable(FunctionType::Evaluator(ev.iter().map(|tb| tb.ty()).collect()))) + }, + Self::Bus(_) => self.ty().kind(), + Self::TraceColumn(tb) | Self::TraceParam(tb) => tb.kind(), + Self::Vector(elems) => elems.kind(), + Self::PublicInput(ty) => ty.kind(), + // NOTE: this may need to be felt? + Self::PeriodicColumn(_) => Some(kind!(bool)), + } + } /// Get the value type of this binding, if applicable fn ty(&self) -> Option { match self { @@ -48,7 +65,7 @@ impl Typing for BindingType { Self::Alias(aliased) => aliased.ty(), Self::Local(ty) | Self::Constant(ty) | Self::PublicInput(ty) => Some(*ty), Self::PeriodicColumn(_) => ty!(felt), - Self::Function(ty) => ty.result(), + Self::Function(_) => None, Self::Evaluator(_) => None, Self::Bus(_) => ty!(felt), } diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index fbcae848e..3d4da28ad 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -670,7 +670,7 @@ impl VisitMut for SemanticAnalysis<'_> { match callee_binding_ty { Ok(ref binding_ty) => { let derived_from = binding_ty.span(); - if let BindingType::Function(ref fty) = binding_ty.item { + if let Some(Kind::Callable(ref fty)) = binding_ty.item.kind() { // There must be an evaluator by this name let qid = expr.callee.resolved().unwrap(); // Builtin functions are ignored here From 5e0cf1ce382b26d5e2e248462e17ea4f20f3fa07 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 7 Aug 2025 09:56:56 +0200 Subject: [PATCH 25/42] fix(typing): re-infer BinExpr bin_ty after updating its arguments --- air/src/tests/trace.rs | 2 +- mir/src/tests/trace.rs | 2 +- parser/src/ast/expression.rs | 46 ++++++++++++++++------------ parser/src/sema/semantic_analysis.rs | 2 +- 4 files changed, 29 insertions(+), 23 deletions(-) diff --git a/air/src/tests/trace.rs b/air/src/tests/trace.rs index 0fb0819bd..710675be7 100644 --- a/air/src/tests/trace.rs +++ b/air/src/tests/trace.rs @@ -151,5 +151,5 @@ fn err_ic_trace_cols_group_used_as_scalar() { enf a[0]' = a + clk; }"; - expect_diagnostic(source, "type mismatch"); + expect_diagnostic(source, "invalid binary expression"); } diff --git a/mir/src/tests/trace.rs b/mir/src/tests/trace.rs index 0fb0819bd..710675be7 100644 --- a/mir/src/tests/trace.rs +++ b/mir/src/tests/trace.rs @@ -151,5 +151,5 @@ fn err_ic_trace_cols_group_used_as_scalar() { enf a[0]' = a + clk; }"; - expect_diagnostic(source, "type mismatch"); + expect_diagnostic(source, "invalid binary expression"); } diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index c88b36593..751a98e39 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -772,32 +772,38 @@ pub struct BinaryExpr { pub op: BinaryOp, pub lhs: Box, pub rhs: Box, - pub bin_ty: BinType, + pub bin_ty: Option, } impl BinaryExpr { pub fn new(span: SourceSpan, op: BinaryOp, lhs: ScalarExpr, rhs: ScalarExpr) -> Self { - debug_assert!( - lhs.ty().is_none() || rhs.ty().is_none() || (lhs.is_scalar() && rhs.is_scalar()), - "binary expression operands must both be scalars, got: {} and {}", - lhs.show_ty(), - rhs.show_ty(), - ); - let l_ty = lhs.scalar_ty(); - let r_ty = rhs.scalar_ty(); - let bin_ty = match op { - BinaryOp::Eq => bty!(l_ty = r_ty), - BinaryOp::Add => bty!(l_ty + r_ty), - BinaryOp::Sub => bty!(l_ty - r_ty), - BinaryOp::Mul => bty!(l_ty * r_ty), - BinaryOp::Exp => bty!(l_ty ^ r_ty), - }; - Self { + let mut res = Self { span, op, lhs: Box::new(lhs), rhs: Box::new(rhs), - bin_ty, + bin_ty: None, + }; + res.update_bin_ty(); + res + } + pub fn update_bin_ty(&mut self) -> Option { + let lhs = self.lhs.as_ref(); + let rhs = self.rhs.as_ref(); + let op = self.op; + if !(lhs.ty().is_some() || rhs.ty().is_some() || (lhs.is_scalar() && rhs.is_scalar())) { + return None; } + let l_ty = lhs.ty(); + let r_ty = rhs.ty(); + let bin_ty = Some(match op { + BinaryOp::Eq => bty!(any:l_ty = any:r_ty), + BinaryOp::Add => bty!(any:l_ty + any:r_ty), + BinaryOp::Sub => bty!(any:l_ty - any:r_ty), + BinaryOp::Mul => bty!(any:l_ty * any:r_ty), + BinaryOp::Exp => bty!(any:l_ty ^ any:r_ty), + }); + self.bin_ty = bin_ty; + bin_ty } /// Returns true if this binary expression could expand to a block, e.g. due to a function call @@ -837,12 +843,12 @@ impl Typing for BinaryExpr { } impl ScalarTypeMut for BinaryExpr { fn scalar_ty_mut(&mut self) -> &mut Option { - self.bin_ty.scalar_ty_mut() + self.bin_ty.as_mut().unwrap().scalar_ty_mut() } } impl TypeMut for BinaryExpr { fn ty_mut(&mut self) -> &mut Option { - self.bin_ty.ty_mut() + self.bin_ty.as_mut().unwrap().ty_mut() } } diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 3d4da28ad..d5b8aea7f 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -740,7 +740,7 @@ impl VisitMut for SemanticAnalysis<'_> { ) -> ControlFlow { self.visit_mut_scalar_expr(expr.lhs.as_mut())?; self.visit_mut_scalar_expr(expr.rhs.as_mut())?; - + let _ = expr.update_bin_ty(); // Validate the operand types match expr.bin_ty.infer_ty() { Ok(None) => { From a37b3d3779eb8f739f73e7b8268d67423766d965 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 7 Aug 2025 11:22:30 +0200 Subject: [PATCH 26/42] fix(typing): sema: report diagnostic on non-constant exponents --- parser/src/ast/expression.rs | 6 +++++ parser/src/sema/semantic_analysis.rs | 2 +- typing/src/lib.rs | 8 ++++++ typing/src/types.rs | 37 ++++++++++++++-------------- 4 files changed, 33 insertions(+), 20 deletions(-) diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index 751a98e39..44a0bc55b 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -834,6 +834,12 @@ impl fmt::Display for BinaryExpr { } } impl Typing for BinaryExpr { + fn infer_ty(&self) -> Result, TypeError> { + match self.bin_ty { + Some(ref bty) => bty.infer_ty(), + None => Ok(None), + } + } fn scalar_ty(&self) -> Option { self.bin_ty.scalar_ty() } diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index d5b8aea7f..9e82fcad5 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -742,7 +742,7 @@ impl VisitMut for SemanticAnalysis<'_> { self.visit_mut_scalar_expr(expr.rhs.as_mut())?; let _ = expr.update_bin_ty(); // Validate the operand types - match expr.bin_ty.infer_ty() { + match expr.infer_ty() { Ok(None) => { self.has_type_errors = true; // Note: We don't break here but at the end of the module's compilation, as we diff --git a/typing/src/lib.rs b/typing/src/lib.rs index 1083d7b9a..34b718829 100644 --- a/typing/src/lib.rs +++ b/typing/src/lib.rs @@ -36,6 +36,10 @@ pub enum TypeError { bin_ty: BinType, span: Option, }, + NonConstantExponent { + bin_ty: BinType, + span: Option, + }, } impl core::fmt::Display for TypeError { @@ -65,6 +69,10 @@ impl core::fmt::Display for TypeError { write!(f, "incompatible types for binary operation: {}", bin_ty.show_fn_ty())?; Ok(()) }, + TypeError::NonConstantExponent { bin_ty, .. } => { + write!(f, "expected exponent to be a constant, got: {}", bin_ty.show_fn_ty())?; + Ok(()) + }, } } } diff --git a/typing/src/types.rs b/typing/src/types.rs index a7ea79a87..1c78aecdf 100644 --- a/typing/src/types.rs +++ b/typing/src/types.rs @@ -620,28 +620,25 @@ impl BinType { /// Returns the type of the result of an exponentiation based on the types /// of the left-hand side and right-hand side operands. - /// If lhs or rhs is not a scalar type or `?`, it returns a [TypeError::IncompatibleBinOp]. + /// If lhs is not a scalar type, or rhs is not `uint`, + /// it returns a [TypeError::IncompatibleBinOp]. /// /// based on the scalar types of the operands: /// ? ^ ? || felt | bool | uint | _ | ? /// =========||======|======|======|======|===== - /// felt || err | err | felt | _ | ? - /// bool || err | err | bool | _ | ? - /// uint || err | err | uint | _ | ? - /// _ || err | err | _ | _ | ? - /// ? || err | err | ? | ? | ? + /// felt || err | err | felt | err | err + /// bool || err | err | bool | err | err + /// uint || err | err | uint | err | err + /// _ || err | err | _ | err | err + /// ? || err | err | ? | err | err /// /// So, the result type of an exponentiation is: - /// - an error if either lhs or rhs is not a scalar type or `?`, - /// - an error if the rhs is not an uint or `?`, - /// - any ^ ? -> ?, - /// - ? ^ any -> ?, - /// - any ^ _ -> _, - /// - any:x ^ uint -> lhs, + /// - an error if either lhs or rhs isn't scalar types, + /// - an error if the rhs is not an uint + /// - the lhs type otherwise /// /// Because: - /// - it is an error if either lhs or rhs is not a scalar type or `?`, - /// - it is an error if rhs is not an uint or `?`, + /// - it is an error if rhs is not an uint /// - a bool to any power is still a bool: /// - 0^n = 0 /// - 1^n = 1 @@ -651,20 +648,22 @@ impl BinType { /// - a ? to any power is still a ? pub fn infer_bin_ty_exp(&self) -> Result, TypeError> { if let Some(ret) = self.result() { + eprintln!("infer_bin_ty_exp: returning cached result {ret:?}"); return Ok(Some(ret)); } let lhs = self.lhs(); let rhs = self.rhs(); + eprintln!("infer_bin_ty_exp: lhs = {lhs:?}, rhs = {rhs:?}"); if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { return Err(TypeError::IncompatibleBinOp { bin_ty: *self, span: None }); } + eprintln!(" MADE IT PAST THE SHAPE CHECK"); match self { - bty!(any ^ felt) | bty!(any ^ bool) => { - Err(TypeError::IncompatibleBinOp { bin_ty: *self, span: None }) + bty!(any ^ uint) => Ok(lhs), + bty!(any ^ felt) | bty!(any ^ bool) | bty!(any ^ _) | bty!(any ^ ?) => { + eprintln!(" ERROR: any ^ !uint"); + Err(TypeError::NonConstantExponent { bin_ty: *self, span: None }) }, - bty!(any ^ ?) | bty!(? ^ any) => Ok(ty!(?)), - bty!(any ^ _) => Ok(ty!(_)), - bty!(any:lhs ^ uint) => Ok(*lhs), _ => unreachable!("Undefined case for infer_bin_ty_exp: {self}"), } } From b2c75b65c88461fcf2a52afb19addf1cf0799d6a Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Fri, 8 Aug 2025 09:21:57 +0200 Subject: [PATCH 27/42] feat(typing): assert_bool primitive --- mir/src/passes/translate.rs | 22 ++++++++ mir/src/tests/mod.rs | 1 + mir/src/tests/typing.rs | 33 +++++++++++ parser/src/ast/expression.rs | 16 +++++- parser/src/sema/semantic_analysis.rs | 56 +++++++++++++++++++ parser/src/symbols.rs | 13 ++++- parser/src/transforms/constant_propagation.rs | 20 +++++++ 7 files changed, 157 insertions(+), 4 deletions(-) create mode 100644 mir/src/tests/typing.rs diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 64ac61832..b1e22d225 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -818,6 +818,28 @@ impl<'a> MirBuilder<'a> { .build(); Ok(node) }, + symbols::AssertBool => { + assert_eq!(call.args.len(), 1); + let x = self.translate_expr(call.args.first().unwrap())?; + // enf x^2 = x + let enforced = + Sub::builder() + .lhs(x.clone()) + .rhs( + Exp::builder() + .lhs(x) + .rhs(self.translate_const( + &ast::ConstantExpr::Scalar(2), + call.span(), + )?) + .span(call.span()) + .build(), + ) + .span(call.span()) + .build(); + let node = Enf::builder().span(call.span()).expr(enforced).build(); + Ok(node) + }, other => unimplemented!("unhandled builtin: {}", other), } } else { diff --git a/mir/src/tests/mod.rs b/mir/src/tests/mod.rs index e569a2d06..8008e87fb 100644 --- a/mir/src/tests/mod.rs +++ b/mir/src/tests/mod.rs @@ -12,6 +12,7 @@ mod pub_inputs; mod selectors; mod source_sections; mod trace; +mod typing; mod variables; use std::sync::Arc; diff --git a/mir/src/tests/typing.rs b/mir/src/tests/typing.rs new file mode 100644 index 000000000..7d4ff3bbd --- /dev/null +++ b/mir/src/tests/typing.rs @@ -0,0 +1,33 @@ +use super::compile; + +#[test] +fn test_typing() { + let code = " + def test + + trace_columns { + main: [a, b], + } + + public_inputs { + stack_inputs: [16], + } + + boundary_constraints { + enf a.first = 0; + } + + integrity_constraints { + let b2 = assert_bool(b); + let c = select(a, b2); + enf c = 42; + } + fn select(x: felt, selector: bool) -> felt { + return x * selector; + } + "; + let Ok(mir) = compile(code) else { + panic!("Failed to compile code: {}", code); + }; + dbg!(&mir.constraint_graph().integrity_constraints_roots); +} diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index 44a0bc55b..f4bb8a8a3 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -160,7 +160,7 @@ impl QualifiedIdentifier { if self.module.name() == "$builtin" { match self.item { NamespacedIdentifier::Function(id) => { - matches!(id.name(), symbols::Sum | symbols::Prod) + matches!(id.name(), symbols::Sum | symbols::Prod | symbols::AssertBool) }, _ => false, } @@ -1422,6 +1422,7 @@ impl Call { match callee.name() { symbols::Sum => Self::sum(span, args), symbols::Prod => Self::prod(span, args), + symbols::AssertBool => Self::assert_bool(span, args), _ => Self { span, callee: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Function(callee)), @@ -1451,6 +1452,19 @@ impl Call { Self::new_builtin(span, "prod", args, ty!(felt).unwrap()) } + /// Constructs a function call for `assert_bool`. + /// An `assert_bool(x)` is equivalent to an `enf x^2 = x plus a cast from felt to bool`. + #[inline] + pub fn assert_bool(span: SourceSpan, args: Vec) -> Self { + //Self::new_builtin(span, "assert_bool", args, ty!(felt).unwrap()) + let builtin_module = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("$builtin")); + let name = Identifier::new(span, Symbol::intern("assert_bool")); + let id = QualifiedIdentifier::new(builtin_module, NamespacedIdentifier::Function(name)); + let callee = ResolvableIdentifier::Resolved(id); + let ty = ty!(bool); + Self { span, callee, args, ty } + } + fn new_builtin(span: SourceSpan, name: &str, args: Vec, ty: Type) -> Self { let builtin_module = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("$builtin")); let name = Identifier::new(span, Symbol::intern(name)); diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index 9e82fcad5..eb539f154 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -1177,6 +1177,56 @@ impl SemanticAnalysis<'_> { }, } }, + // The known built-in cast functions - each takes a single argument, which + // must be a subtype of the expected type + symbols::AssertBool => { + match call.args.as_slice() { + [arg] => { + match self.expr_binding_type(arg) { + Ok(binding_ty) => { + if !binding_ty.ty().map(|t| t.is_scalar()).unwrap_or(false) { + self.has_type_errors = true; + self.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid call") + .with_primary_label( + call.span(), + "this function expects an argument of scalar type", + ) + .with_secondary_label( + arg.span(), + format!( + "but this argument is a {}", + binding_ty.show_kind() + ), + ) + .emit(); + } + }, + Err(e) => { + eprintln!("error: {e}"); + // We've already raised a diagnostic for this when visiting the + // access expression + assert!(self.has_undefined_variables || self.has_type_errors); + }, + } + }, + _ => { + self.has_type_errors = true; + self.diagnostics + .diagnostic(Severity::Error) + .with_message("invalid call") + .with_primary_label( + call.span(), + format!( + "the callee expects a single argument, but got {}", + call.args.len() + ), + ) + .emit(); + }, + } + }, other => unimplemented!("unrecognized builtin function: {}", other), } ControlFlow::Continue(()) @@ -1636,6 +1686,7 @@ impl SemanticAnalysis<'_> { None => { // If the call was resolved, it must be to an imported function, // and we will have already validated the reference + dbg!(&id); let (import_id, module_id) = self.imported.get_key_value(&id).unwrap(); let module = self.library.get(module_id).unwrap(); if !module.evaluators.contains_key(&id.id()) { @@ -1936,6 +1987,11 @@ impl SemanticAnalysis<'_> { let folder_ty = FunctionType::Function(vec![ty!(felt[usize::MAX])], ty!(felt)); Ok(Span::new(qid.span(), BindingType::Function(folder_ty))) }, + symbols::AssertBool => { + // An `assert_bool(x)` is equivalent to an `enf x^2 = x and + // a cast from felt to bool`. + Ok(Span::new(qid.span(), BindingType::Function(fty!(fn(felt) -> bool)))) + }, name => unimplemented!("unsupported builtin: {}", name), } } else { diff --git a/parser/src/symbols.rs b/parser/src/symbols.rs index 0b5022fc0..d917728f1 100644 --- a/parser/src/symbols.rs +++ b/parser/src/symbols.rs @@ -17,9 +17,16 @@ pub mod predefined { pub const Sum: Symbol = Symbol::new(2); /// The symbol `prod` pub const Prod: Symbol = Symbol::new(3); - - pub(super) const __SYMBOLS: &[(Symbol, &str)] = - &[(Main, "$main"), (Builtin, "$builtin"), (Sum, "sum"), (Prod, "prod")]; + /// The symbol `assert_bool` + pub const AssertBool: Symbol = Symbol::new(4); + + pub(super) const __SYMBOLS: &[(Symbol, &str)] = &[ + (Main, "$main"), + (Builtin, "$builtin"), + (Sum, "sum"), + (Prod, "prod"), + (AssertBool, "assert_bool"), + ]; } pub use self::predefined::*; diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index e6db48023..c5d3fceca 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -402,6 +402,26 @@ impl VisitMut for ConstantPropagation<'_> { } } }, + symbols::AssertBool => { + assert_eq!(call.args.len(), 1); + match &call.args[0] { + // If the assertion is a constant 0 or 1, it's valid + // TODO: if we start allowing casts from a uint to a bool, we should + // fold the assertion to a rebind of type bool if it is 0 or 1, + // and raise a diagnostic if it is not + Expr::Const(Span { item: ConstantExpr::Scalar(0 | 1), .. }) => {}, + // If the assertion is not 0 or 1, emit an error + Expr::Const(Span { item: ConstantExpr::Scalar(_), .. }) => { + self.diagnostics + .diagnostic(miden_diagnostics::Severity::Error) + .with_message("assertion failed") + .with_primary_label(span, "assertion failed") + .emit(); + return ControlFlow::Break(SemanticAnalysisError::Invalid); + }, + _ => {}, + } + }, invalid => unimplemented!("unknown builtin function: {invalid}"), } ControlFlow::Continue(()) From c8cfad323e5bbee1e6aaee44d404dda696d092f4 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Fri, 8 Aug 2025 18:00:15 +0200 Subject: [PATCH 28/42] fix(mir/translate): fix inserted enf expr --- mir/src/passes/translate.rs | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index b1e22d225..12585f14f 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -827,7 +827,7 @@ impl<'a> MirBuilder<'a> { .lhs(x.clone()) .rhs( Exp::builder() - .lhs(x) + .lhs(x.clone()) .rhs(self.translate_const( &ast::ConstantExpr::Scalar(2), call.span(), @@ -838,7 +838,11 @@ impl<'a> MirBuilder<'a> { .span(call.span()) .build(); let node = Enf::builder().span(call.span()).expr(enforced).build(); - Ok(node) + let _ = self.insert_enforce(node); + let bool_x = duplicate_node(x, &mut Default::default()); + // TODO: cast to a bool + //bool_x.update_ty(ty!(bool)); + Ok(bool_x) }, other => unimplemented!("unhandled builtin: {}", other), } From 9419d39622d547a4f100a09a277d39b20c63a068 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Tue, 26 Aug 2025 15:02:03 +0200 Subject: [PATCH 29/42] refactor(typing): rename typing crate to air_types --- Cargo.toml | 3 ++- mir/Cargo.toml | 2 +- mir/src/ir/mod.rs | 6 +++--- mir/src/ir/nodes/ops/mod.rs | 2 +- mir/src/ir/nodes/ops/parameter.rs | 2 +- mir/src/passes/translate.rs | 2 +- parser/Cargo.toml | 2 +- parser/src/ast/trace.rs | 2 +- parser/src/ast/types.rs | 2 +- parser/src/sema/binding_type.rs | 2 +- {typing => types}/Cargo.toml | 4 ++-- {typing => types}/src/lib.rs | 0 {typing => types}/src/types.rs | 0 13 files changed, 15 insertions(+), 14 deletions(-) rename {typing => types}/Cargo.toml (87%) rename {typing => types}/src/lib.rs (100%) rename {typing => types}/src/types.rs (100%) diff --git a/Cargo.toml b/Cargo.toml index 495ef6a7b..46c2f0794 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,7 +8,7 @@ members = [ "air", "codegen/winterfell", "codegen/ace", - "typing", + "types", ] resolver = "2" @@ -23,3 +23,4 @@ rust-version = "1.87" anyhow = "1.0" miden-diagnostics = "0.1" thiserror = "2.0" +air-types = { version = "1.0", path = "types" } diff --git a/mir/Cargo.toml b/mir/Cargo.toml index faf56c817..c1d8e0c43 100644 --- a/mir/Cargo.toml +++ b/mir/Cargo.toml @@ -14,7 +14,7 @@ edition.workspace = true [dependencies] air-parser = { package = "air-parser", path = "../parser", version = "0.5" } air-pass = { package = "air-pass", path = "../pass", version = "0.5" } -typing = { package = "typing", path = "../typing", version = "0.1" } +air-types.workspace = true anyhow = { workspace = true } derive-ir = { package = "air-derive-ir", path = "./derive-ir", version = "0.5" } miden-core = { package = "miden-core", version = "0.13", default-features = false } diff --git a/mir/src/ir/mod.rs b/mir/src/ir/mod.rs index 64cb5356a..70bb26582 100644 --- a/mir/src/ir/mod.rs +++ b/mir/src/ir/mod.rs @@ -7,9 +7,11 @@ mod nodes; mod owner; mod quad_eval; mod utils; +pub extern crate air_types; pub extern crate derive_ir; -pub extern crate typing; +#[allow(unused_imports)] +pub use air_types::*; pub use bus::Bus; pub use derive_ir::Builder; pub use graph::Graph; @@ -19,8 +21,6 @@ pub use node::Node; pub use nodes::*; pub use owner::Owner; pub use quad_eval::{QuadFelt, RandomInputs}; -#[allow(unused_imports)] -pub use typing::*; pub use utils::*; /// A trait for nodes that can have children /// This is used with the Child trait to allow for easy traversal and manipulation of the graph diff --git a/mir/src/ir/nodes/ops/mod.rs b/mir/src/ir/nodes/ops/mod.rs index 22a472075..9e0b48040 100644 --- a/mir/src/ir/nodes/ops/mod.rs +++ b/mir/src/ir/nodes/ops/mod.rs @@ -17,6 +17,7 @@ mod vector; pub use accessor::Accessor; pub use add::Add; +pub use air_types::*; pub use boundary::Boundary; pub use bus_op::{BusOp, BusOpKind}; pub use call::Call; @@ -29,7 +30,6 @@ pub use matrix::Matrix; pub use mul::Mul; pub use parameter::Parameter; pub use sub::Sub; -pub use typing::*; pub use value::{ BusAccess, ConstantValue, MirValue, PeriodicColumnAccess, PublicInputAccess, PublicInputTableAccess, SpannedMirValue, TraceAccess, TraceAccessBinding, Value, diff --git a/mir/src/ir/nodes/ops/parameter.rs b/mir/src/ir/nodes/ops/parameter.rs index 76e552fd5..6bc3eedae 100644 --- a/mir/src/ir/nodes/ops/parameter.rs +++ b/mir/src/ir/nodes/ops/parameter.rs @@ -1,7 +1,7 @@ use std::hash::{Hash, Hasher}; +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; -use typing::*; use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Singleton}; diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 12585f14f..3dd0a2406 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -3,8 +3,8 @@ use std::ops::Deref; use air_parser::{LexicalScope, ast, ast::AccessType, symbols}; use air_pass::Pass; +use air_types::*; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; -use typing::*; use crate::{ CompileError, diff --git a/parser/Cargo.toml b/parser/Cargo.toml index 29b8029db..6b0752713 100644 --- a/parser/Cargo.toml +++ b/parser/Cargo.toml @@ -16,7 +16,7 @@ lalrpop = { version = "0.20", default-features = false } [dependencies] air-pass = { package = "air-pass", path = "../pass", version = "0.5" } -typing = { package = "typing", path = "../typing", version = "0.1" } +air-types.workspace = true either = "1.12" lalrpop-util = "0.20" lazy_static = "1.4" diff --git a/parser/src/ast/trace.rs b/parser/src/ast/trace.rs index c5560a328..1260e9a61 100644 --- a/parser/src/ast/trace.rs +++ b/parser/src/ast/trace.rs @@ -1,7 +1,7 @@ use std::fmt; +use air_types::{Kind, Typing, tty, ty}; use miden_diagnostics::{SourceSpan, Spanned}; -use typing::{Kind, Typing, tty, ty}; use super::*; diff --git a/parser/src/ast/types.rs b/parser/src/ast/types.rs index f86653793..7a86fc1f5 100644 --- a/parser/src/ast/types.rs +++ b/parser/src/ast/types.rs @@ -1,4 +1,4 @@ -pub use typing::{bty, fty, kind, sty, tty, ty, tys, *}; +pub use air_types::{bty, fty, kind, sty, tty, ty, tys, *}; use super::*; diff --git a/parser/src/sema/binding_type.rs b/parser/src/sema/binding_type.rs index 016f1c2b0..a75c25f98 100644 --- a/parser/src/sema/binding_type.rs +++ b/parser/src/sema/binding_type.rs @@ -1,6 +1,6 @@ use std::fmt; -use typing::*; +use air_types::*; use crate::ast::{ Access, AccessType, BusType, FunctionType, InvalidAccessError, TraceBinding, TraceSegment, Type, diff --git a/typing/Cargo.toml b/types/Cargo.toml similarity index 87% rename from typing/Cargo.toml rename to types/Cargo.toml index 3095ca63a..e7b546375 100644 --- a/typing/Cargo.toml +++ b/types/Cargo.toml @@ -1,6 +1,6 @@ [package] -name = "typing" -version = "0.1.0" +name = "air-types" +version = "1.0.0" authors.workspace = true license.workspace = true repository.workspace = true diff --git a/typing/src/lib.rs b/types/src/lib.rs similarity index 100% rename from typing/src/lib.rs rename to types/src/lib.rs diff --git a/typing/src/types.rs b/types/src/types.rs similarity index 100% rename from typing/src/types.rs rename to types/src/types.rs From 3210f3602cba5365aaf776f97b7468e7a7edee44 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 11 Sep 2025 11:30:15 +0200 Subject: [PATCH 30/42] feat(types): FunctionType::check_args_kinds --- types/src/types.rs | 70 ++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 64 insertions(+), 6 deletions(-) diff --git a/types/src/types.rs b/types/src/types.rs index 1c78aecdf..88717874e 100644 --- a/types/src/types.rs +++ b/types/src/types.rs @@ -146,13 +146,13 @@ macro_rules! kinds { $res }; (RES: $res:expr; ?) => { - kinds!(RES: $crate::Push::push($res, $crate::kind!(?));) + kinds!(RES: $crate::Push::push($res, Option::Some(Box::new($crate::kind!(?))));) }; (RES: $res:expr; _$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { - kinds!(RES: $crate::Push::push($res, $crate::kind!(_$([$($spec)+])?)); $($($rest)+)?) + kinds!(RES: $crate::Push::push($res, Option::Some(Box::new($crate::kind!(_$([$($spec)+])?)))); $($($rest)+)?) }; (RES: $res:expr; $name:ident$([$($spec:tt)+])? $(, $($rest:tt)+)?) => { - kinds!(RES: $crate::Push::push($res, $crate::kind!($name$([$($spec)+])?)); $($($rest)+)?) + kinds!(RES: $crate::Push::push($res, Option::Some(Box::new($crate::kind!($name$([$($spec)+])?)))); $($($rest)+)?) }; } @@ -185,10 +185,10 @@ pub enum FunctionType { } impl FunctionType { - pub fn args(&self) -> &[Option] { + pub fn params(&self) -> &[Option] { match self { - Self::Evaluator(args) => args, - Self::Function(args, _) => args, + Self::Evaluator(params) => params, + Self::Function(params, _) => params, } } @@ -198,6 +198,23 @@ impl FunctionType { Self::Function(_, ret) => *ret, } } + + pub fn check_args_kinds(&self, args: &[&Kind]) -> bool { + eprintln!("Checking function type {} against params {:?}", self, args); + let params = self.params(); + if params.len() != args.len() { + return false; + } + for (arg_ty, param_kind) in args.iter().zip(params.iter()) { + eprintln!(" Checking arg_ty {arg_ty:?} against param_kind {param_kind:?}"); + if !arg_ty.is_subtype(param_kind) { + eprintln!(" Failed!: {arg_ty:?} is not a subtype of {param_kind:?}"); + return false; + } + } + eprintln!(" Success!"); + true + } } impl core::fmt::Display for FunctionType { @@ -796,4 +813,45 @@ mod tests { assert_eq!(kind!(_), Kind::Value(ty!(_))); assert_eq!(kind!(bool[3, 4]), Kind::Value(ty!(bool[3, 4]))); } + + #[test] + fn test_fn_ty_check_param_kinds() { + // Scalar types + assert!(fty!(fn(uint, felt) -> felt).check_args_kinds(&[&kind!(uint), &kind!(felt)]),); + assert!(fty!(fn(felt, felt) -> felt).check_args_kinds(&[&kind!(felt), &kind!(bool)]),); + // Vector types + assert!(fty!(fn(_[3], felt[2]) -> felt).check_args_kinds(&[&kind!(_[3]), &kind!(felt[2])])); + assert!( + fty!(fn(_[3], felt[2]) -> felt).check_args_kinds(&[&kind!(bool[3]), &kind!(uint[2])]) + ); + // Aggregate types + assert!(fty!(fn(felt[2], bool[3], uint[2]) -> felt).check_args_kinds(&[ + &kind!([bool, uint]), + &kind!([bool, bool, bool]), + &kind!([uint, uint]), + ])); + + // Negative cases + + // Scalar types + assert!(!fty!(fn(uint, bool) -> felt).check_args_kinds(&[&kind!(bool), &kind!(felt)]),); + assert!(!fty!(fn(felt, bool) -> felt).check_args_kinds(&[&kind!(felt), &kind!(uint[2])]),); + // Vector types + assert!(!fty!(fn(_[3], felt[2]) -> felt).check_args_kinds(&[&kind!(_), &kind!(felt[2])])); + assert!( + !fty!(fn(_[3], felt[2]) -> felt) + .check_args_kinds(&[&kind!(bool[3, 5]), &kind!(uint[2])]) + ); + // Aggregate types + assert!(!fty!(fn(felt[2], bool[3], uint[2]) -> felt).check_args_kinds(&[ + &kind!([bool, uint]), + &kind!([bool, felt, uint]), + &kind!([uint, uint]), + ])); + assert!(!fty!(fn(felt[2], bool[3], uint[2]) -> felt).check_args_kinds(&[ + &kind!([bool, uint]), + &kind!([bool, bool]), + &kind!([uint, uint]), + ])); + } } From 1ceaadc1db0107b204dbdf565e59d754b477a174 Mon Sep 17 00:00:00 2001 From: Leo-Besancon Date: Fri, 12 Sep 2025 10:31:19 +0200 Subject: [PATCH 31/42] feat: implement typing for MIR nodes --- mir/derive-ir/src/builder.rs | 12 +- mir/src/ir/bus.rs | 2 +- mir/src/ir/link.rs | 10 ++ mir/src/ir/mod.rs | 4 + mir/src/ir/node.rs | 12 +- mir/src/ir/nodes/mod.rs | 2 + mir/src/ir/nodes/none.rs | 33 ++++++ mir/src/ir/nodes/op.rs | 109 +++++++++++++++++- mir/src/ir/nodes/ops/accessor.rs | 43 ++++++- mir/src/ir/nodes/ops/add.rs | 40 ++++++- mir/src/ir/nodes/ops/boundary.rs | 32 ++++- mir/src/ir/nodes/ops/bus_op.rs | 26 ++++- mir/src/ir/nodes/ops/call.rs | 37 +++++- mir/src/ir/nodes/ops/enf.rs | 32 ++++- mir/src/ir/nodes/ops/exp.rs | 40 ++++++- mir/src/ir/nodes/ops/fold.rs | 24 +++- mir/src/ir/nodes/ops/for_op.rs | 24 +++- mir/src/ir/nodes/ops/if_op.rs | 24 +++- mir/src/ir/nodes/ops/matrix.rs | 35 +++++- mir/src/ir/nodes/ops/mod.rs | 1 - mir/src/ir/nodes/ops/mul.rs | 40 ++++++- mir/src/ir/nodes/ops/parameter.rs | 28 ++++- mir/src/ir/nodes/ops/sub.rs | 40 ++++++- mir/src/ir/nodes/ops/value.rs | 24 +++- mir/src/ir/nodes/ops/vector.rs | 35 +++++- mir/src/ir/nodes/root.rs | 35 ++++-- mir/src/ir/nodes/roots/evaluator.rs | 14 ++- mir/src/ir/nodes/roots/function.rs | 14 ++- mir/src/ir/owner.rs | 12 +- mir/src/passes/mod.rs | 2 +- mir/src/passes/translate.rs | 105 ++++++++++++++--- .../passes/unrolling/unrolling_first_pass.rs | 6 +- parser/src/sema/semantic_analysis.rs | 32 ++--- types/src/types.rs | 22 +++- 34 files changed, 838 insertions(+), 113 deletions(-) create mode 100644 mir/src/ir/nodes/none.rs diff --git a/mir/derive-ir/src/builder.rs b/mir/derive-ir/src/builder.rs index 92857403b..fd3b8398b 100644 --- a/mir/derive-ir/src/builder.rs +++ b/mir/derive-ir/src/builder.rs @@ -564,20 +564,24 @@ fn make_build_method( match enum_wrapper { EnumWrapper::Op => quote! { pub fn build(&self) -> crate::ir::Link { - Op::#name( + let mut op = Op::#name( #name { #(#fields),* } - ).into() + ); + op.finalize_hook(); + op.into() } }, EnumWrapper::Root => quote! { pub fn build(&self) -> crate::ir::Link { - Root::#name( + let mut root = Root::#name( #name { #(#fields),* } - ).into() + ); + root.finalize_hook(); + root.into() } }, } diff --git a/mir/src/ir/bus.rs b/mir/src/ir/bus.rs index 57822c7f3..81728f0dc 100644 --- a/mir/src/ir/bus.rs +++ b/mir/src/ir/bus.rs @@ -166,7 +166,7 @@ impl Link { for column in columns { bus_op = bus_op.args(column.clone()); } - let bus_op = bus_op.latch(latch.clone()).build(); + let bus_op = bus_op.latch(latch.clone()).ty(None).build(); self.borrow_mut().columns.push(bus_op.clone()); self.borrow_mut().latches.push(latch.clone()); bus_op diff --git a/mir/src/ir/link.rs b/mir/src/ir/link.rs index db1cab648..376d9f2df 100644 --- a/mir/src/ir/link.rs +++ b/mir/src/ir/link.rs @@ -5,6 +5,7 @@ use std::{ rc::{Rc, Weak}, }; +use air_types::Typing; use miden_diagnostics::{SourceSpan, Spanned}; /// A wrapper around a `Rc>` to allow custom trait implementations. @@ -110,6 +111,15 @@ where } } +impl Typing for Link +where + T: Typing, +{ + fn ty(&self) -> Option { + self.borrow().ty() + } +} + /// A wrapper around a `Option>>` to allow custom trait implementations. /// Used instead of `Link` where a `Link` would create a cyclIc reference. pub struct BackLink { diff --git a/mir/src/ir/mod.rs b/mir/src/ir/mod.rs index c15c7ddb5..05b4ba2ff 100644 --- a/mir/src/ir/mod.rs +++ b/mir/src/ir/mod.rs @@ -116,3 +116,7 @@ pub trait Builder { /// Create a new empty builder that exposes all fields fn builder() -> Self::Empty; } + +pub trait BuilderHook { + fn finalize_hook(&mut self) {} +} diff --git a/mir/src/ir/node.rs b/mir/src/ir/node.rs index 9bf47a696..30fe5925d 100644 --- a/mir/src/ir/node.rs +++ b/mir/src/ir/node.rs @@ -1,8 +1,8 @@ use std::ops::Deref; -use miden_diagnostics::{SourceSpan, Spanned}; +use miden_diagnostics::Spanned; -use crate::ir::{BackLink, Child, Link, Op, Owner, Parent, Root}; +use crate::ir::{BackLink, Child, Link, None, Op, Owner, Parent, Root}; /// All the nodes that can be in the MIR Graph /// Combines all [Root] and [Op] variants @@ -33,7 +33,7 @@ pub enum Node { BusOp(BackLink), Parameter(BackLink), Value(BackLink), - None(SourceSpan), + None(None), } impl Default for Node { @@ -220,17 +220,17 @@ impl Link { Op::BusOp(_) => Node::BusOp(BackLink::from(op_inner_val)), Op::Parameter(_) => Node::Parameter(BackLink::from(op_inner_val)), Op::Value(_) => Node::Value(BackLink::from(op_inner_val)), - Op::None(span) => Node::None(*span), + Op::None(none) => Node::None(none.clone()), }; } else if let Some(root_inner_val) = self.as_root() { to_update = match root_inner_val.clone().borrow().deref() { Root::Function(_) => Node::Function(BackLink::from(root_inner_val)), Root::Evaluator(_) => Node::Evaluator(BackLink::from(root_inner_val)), - Root::None(span) => Node::None(*span), + Root::None(none) => Node::None(none.clone()), }; } else { // If the [Node] is stale, we set it to None - to_update = Node::None(self.span()); + to_update = Node::None(None { span: self.span(), ty: None }); } *self.borrow_mut() = to_update; diff --git a/mir/src/ir/nodes/mod.rs b/mir/src/ir/nodes/mod.rs index 5ddf29733..ae26a94d3 100644 --- a/mir/src/ir/nodes/mod.rs +++ b/mir/src/ir/nodes/mod.rs @@ -1,3 +1,4 @@ +pub mod none; mod op; mod ops; mod root; @@ -5,6 +6,7 @@ mod roots; use std::cell::{Ref, RefMut}; +pub use none::*; pub use op::Op; pub use ops::*; pub use root::Root; diff --git a/mir/src/ir/nodes/none.rs b/mir/src/ir/nodes/none.rs new file mode 100644 index 000000000..640fe9dfe --- /dev/null +++ b/mir/src/ir/nodes/none.rs @@ -0,0 +1,33 @@ +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; +use miden_diagnostics::{SourceSpan, Spanned}; + +#[derive(Clone, PartialEq, Eq, Debug, Hash, Spanned)] +pub struct None { + #[span] + pub span: SourceSpan, + pub ty: Option, +} + +impl ScalarTypeMut for None { + fn scalar_ty_mut(&mut self) -> &mut Option { + self.ty.scalar_ty_mut() + } +} + +impl TypeMut for None { + fn ty_mut(&mut self) -> &mut Option { + self.ty.ty_mut() + } +} + +impl Typing for None { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl Default for None { + fn default() -> Self { + None { span: Default::default(), ty: None } + } +} diff --git a/mir/src/ir/nodes/op.rs b/mir/src/ir/nodes/op.rs index 4d903dc5b..6b269f1dc 100644 --- a/mir/src/ir/nodes/op.rs +++ b/mir/src/ir/nodes/op.rs @@ -3,12 +3,13 @@ use std::{ ops::{Deref, DerefMut}, }; -use miden_diagnostics::{SourceSpan, Spanned}; +use air_types::{ScalarTypeMut, TypeMut, Typing}; +use miden_diagnostics::Spanned; use crate::ir::{ - Accessor, Add, BackLink, Boundary, BusOp, Call, Child, ConstantValue, Enf, Exp, Fold, For, If, - Link, Matrix, MirValue, Mul, Node, Owner, Parameter, Parent, Singleton, SpannedMirValue, Sub, - Value, Vector, get_inner, get_inner_mut, + Accessor, Add, BackLink, Boundary, BuilderHook, BusOp, Call, Child, ConstantValue, Enf, Exp, + Fold, For, If, Link, Matrix, MirValue, Mul, Node, None, Owner, Parameter, Parent, Singleton, + SpannedMirValue, Sub, Value, Vector, get_inner, get_inner_mut, }; /// The combined [Op]s and leaves of the MIR Graph. @@ -33,7 +34,31 @@ pub enum Op { BusOp(BusOp), Parameter(Parameter), Value(Value), - None(SourceSpan), + None(None), +} + +impl BuilderHook for Op { + fn finalize_hook(&mut self) { + match self { + Op::Enf(e) => e.finalize_hook(), + Op::Boundary(b) => b.finalize_hook(), + Op::Add(a) => a.finalize_hook(), + Op::Sub(s) => s.finalize_hook(), + Op::Mul(m) => m.finalize_hook(), + Op::Exp(e) => e.finalize_hook(), + Op::If(i) => i.finalize_hook(), + Op::For(f) => f.finalize_hook(), + Op::Call(c) => c.finalize_hook(), + Op::Fold(f) => f.finalize_hook(), + Op::Vector(v) => v.finalize_hook(), + Op::Matrix(m) => m.finalize_hook(), + Op::Accessor(a) => a.finalize_hook(), + Op::BusOp(b) => b.finalize_hook(), + Op::Parameter(p) => p.finalize_hook(), + Op::Value(v) => v.finalize_hook(), + Op::None(_) => {}, + } + } } impl Default for Op { @@ -134,6 +159,78 @@ impl Child for Op { } } +impl ScalarTypeMut for Op { + fn scalar_ty_mut(&mut self) -> &mut Option { + match self { + Op::Enf(e) => e.scalar_ty_mut(), + Op::Boundary(b) => b.scalar_ty_mut(), + Op::Add(a) => a.scalar_ty_mut(), + Op::Sub(s) => s.scalar_ty_mut(), + Op::Mul(m) => m.scalar_ty_mut(), + Op::Exp(e) => e.scalar_ty_mut(), + Op::If(i) => i.scalar_ty_mut(), + Op::For(f) => f.scalar_ty_mut(), + Op::Call(c) => c.scalar_ty_mut(), + Op::Fold(f) => f.scalar_ty_mut(), + Op::Vector(v) => v.scalar_ty_mut(), + Op::Matrix(m) => m.scalar_ty_mut(), + Op::Accessor(a) => a.scalar_ty_mut(), + Op::BusOp(b) => b.scalar_ty_mut(), + Op::Parameter(p) => p.scalar_ty_mut(), + Op::Value(v) => v.scalar_ty_mut(), + Op::None(n) => n.scalar_ty_mut(), + } + } +} + +impl TypeMut for Op { + fn ty_mut(&mut self) -> &mut Option { + match self { + Op::Enf(e) => e.ty_mut(), + Op::Boundary(b) => b.ty_mut(), + Op::Add(a) => a.ty_mut(), + Op::Sub(s) => s.ty_mut(), + Op::Mul(m) => m.ty_mut(), + Op::Exp(e) => e.ty_mut(), + Op::If(i) => i.ty_mut(), + Op::For(f) => f.ty_mut(), + Op::Call(c) => c.ty_mut(), + Op::Fold(f) => f.ty_mut(), + Op::Vector(v) => v.ty_mut(), + Op::Matrix(m) => m.ty_mut(), + Op::Accessor(a) => a.ty_mut(), + Op::BusOp(b) => b.ty_mut(), + Op::Parameter(p) => p.ty_mut(), + Op::Value(v) => v.ty_mut(), + Op::None(n) => n.ty_mut(), + } + } +} + +impl Typing for Op { + fn ty(&self) -> Option { + match self { + Op::Enf(e) => e.ty(), + Op::Boundary(b) => b.ty(), + Op::Add(a) => a.ty(), + Op::Sub(s) => s.ty(), + Op::Mul(m) => m.ty(), + Op::Exp(e) => e.ty(), + Op::If(i) => i.ty(), + Op::For(f) => f.ty(), + Op::Call(c) => c.ty(), + Op::Fold(f) => f.ty(), + Op::Vector(v) => v.ty(), + Op::Matrix(m) => m.ty(), + Op::Accessor(a) => a.ty(), + Op::BusOp(b) => b.ty(), + Op::Parameter(p) => p.ty(), + Op::Value(v) => v.ty(), + Op::None(n) => n.ty(), + } + } +} + impl Link { /// Debug the current [Op], showing [std::cell::RefCell]'s `@{pointer}` and inner struct. /// This is useful to debug shared mutability issues. @@ -385,7 +482,7 @@ impl Link { value._node = Singleton::from(node.clone()); node }, - Op::None(span) => Node::None(*span).into(), + Op::None(none) => Node::None(none.clone()).into(), } } /// Try getting the current [Op]'s [Owner] variant, diff --git a/mir/src/ir/nodes/ops/accessor.rs b/mir/src/ir/nodes/ops/accessor.rs index aea462cf4..1bbc2929d 100644 --- a/mir/src/ir/nodes/ops/accessor.rs +++ b/mir/src/ir/nodes/ops/accessor.rs @@ -1,9 +1,10 @@ use std::hash::Hash; -use air_parser::ast::AccessType; +use air_parser::ast::{Access, AccessType}; +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent accessing a given op, `indexable`, in two different ways: /// - access_type: AccessType, which describes for example how to access a given index for a Vector @@ -20,6 +21,35 @@ pub struct Accessor { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _ty: Option, +} + +impl ScalarTypeMut for Accessor { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._ty.scalar_ty_mut() + } +} + +impl TypeMut for Accessor { + fn ty_mut(&mut self) -> &mut Option { + self._ty.ty_mut() + } +} + +impl Typing for Accessor { + fn ty(&self) -> Option { + self._ty.ty() + } +} + +impl BuilderHook for Accessor { + fn finalize_hook(&mut self) { + self._ty = self + .indexable + .borrow() + .ty() + .map(|ty| ty.access(self.access_type.clone()).unwrap()); + } } impl Accessor { @@ -29,14 +59,15 @@ impl Accessor { offset: usize, span: SourceSpan, ) -> Link { - Op::Accessor(Self { - access_type, + let mut accessor = Self { indexable, + access_type, offset, span, ..Default::default() - }) - .into() + }; + accessor.finalize_hook(); + Op::Accessor(accessor).into() } } diff --git a/mir/src/ir/nodes/ops/add.rs b/mir/src/ir/nodes/ops/add.rs index e5bfe1ebf..e1b107851 100644 --- a/mir/src/ir/nodes/ops/add.rs +++ b/mir/src/ir/nodes/ops/add.rs @@ -1,6 +1,7 @@ +use air_types::{BinType, ScalarTypeMut, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent the addition of two MIR ops, `lhs` and `rhs` #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -13,11 +14,46 @@ pub struct Add { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _bin_ty: BinType, +} + +impl ScalarTypeMut for Add { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._bin_ty.scalar_ty_mut() + } +} + +impl TypeMut for Add { + fn ty_mut(&mut self) -> &mut Option { + self._bin_ty.ty_mut() + } +} + +impl Typing for Add { + fn ty(&self) -> Option { + self._bin_ty.ty() + } +} + +impl BuilderHook for Add { + fn finalize_hook(&mut self) { + self._bin_ty = BinType::Add(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); + let res = self._bin_ty.infer_bin_ty_add().unwrap(); + *self._bin_ty.result_mut() = res; + } } impl Add { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Add(Self { lhs, rhs, span, ..Default::default() }).into() + let mut add = Self { + lhs, + rhs, + span, + _bin_ty: BinType::default(), + ..Default::default() + }; + add.finalize_hook(); + Op::Add(add).into() } } diff --git a/mir/src/ir/nodes/ops/boundary.rs b/mir/src/ir/nodes/ops/boundary.rs index 8785fa018..1b8797278 100644 --- a/mir/src/ir/nodes/ops/boundary.rs +++ b/mir/src/ir/nodes/ops/boundary.rs @@ -1,9 +1,10 @@ use std::hash::Hash; use air_parser::ast::Boundary as BoundaryKind; +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent bounding a given op, `expr`, to access either the first or last row /// @@ -18,6 +19,31 @@ pub struct Boundary { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _ty: Option, +} + +impl ScalarTypeMut for Boundary { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._ty.scalar_ty_mut() + } +} + +impl TypeMut for Boundary { + fn ty_mut(&mut self) -> &mut Option { + self._ty.ty_mut() + } +} + +impl Typing for Boundary { + fn ty(&self) -> Option { + self._ty.ty() + } +} + +impl BuilderHook for Boundary { + fn finalize_hook(&mut self) { + self._ty = self.expr.borrow().ty(); + } } impl Hash for Boundary { @@ -32,7 +58,9 @@ impl Hash for Boundary { impl Boundary { pub fn create(expr: Link, kind: BoundaryKind, span: SourceSpan) -> Link { - Op::Boundary(Self { expr, kind, span, ..Default::default() }).into() + let mut boundary = Self { expr, kind, span, ..Default::default() }; + boundary.finalize_hook(); + Op::Boundary(boundary).into() } } diff --git a/mir/src/ir/nodes/ops/bus_op.rs b/mir/src/ir/nodes/ops/bus_op.rs index a868275ca..e57381597 100644 --- a/mir/src/ir/nodes/ops/bus_op.rs +++ b/mir/src/ir/nodes/ops/bus_op.rs @@ -1,8 +1,11 @@ use std::hash::Hash; +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{ + BackLink, Builder, BuilderHook, Bus, Child, Link, Node, Op, Owner, Parent, Singleton, +}; #[derive(Clone, Copy, PartialEq, Eq, Debug, Default, Hash)] pub enum BusOpKind { @@ -23,8 +26,29 @@ pub struct BusOp { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub ty: Option, } +impl ScalarTypeMut for BusOp { + fn scalar_ty_mut(&mut self) -> &mut Option { + self.ty.scalar_ty_mut() + } +} + +impl TypeMut for BusOp { + fn ty_mut(&mut self) -> &mut Option { + self.ty.ty_mut() + } +} + +impl Typing for BusOp { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for BusOp {} + impl Hash for BusOp { fn hash(&self, state: &mut H) { self.bus.get_name().hash(state); diff --git a/mir/src/ir/nodes/ops/call.rs b/mir/src/ir/nodes/ops/call.rs index efa169328..1ee4537a3 100644 --- a/mir/src/ir/nodes/ops/call.rs +++ b/mir/src/ir/nodes/ops/call.rs @@ -1,6 +1,9 @@ +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Root, Singleton}; +use crate::ir::{ + BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Root, Singleton, +}; /// A MIR operation to represent a call to a given function, a `Root` that represents either a /// `Function` or an `Evaluator` @@ -21,17 +24,43 @@ pub struct Call { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _ty: Option, +} + +impl ScalarTypeMut for Call { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._ty.scalar_ty_mut() + } +} + +impl TypeMut for Call { + fn ty_mut(&mut self) -> &mut Option { + self._ty.ty_mut() + } +} + +impl Typing for Call { + fn ty(&self) -> Option { + self._ty.ty() + } +} + +impl BuilderHook for Call { + fn finalize_hook(&mut self) { + self._ty = self.function.borrow().ty(); + } } impl Call { pub fn create(function: Link, arguments: Vec>, span: SourceSpan) -> Link { - Op::Call(Self { + let mut call = Self { function, arguments: Link::new(arguments), span, ..Default::default() - }) - .into() + }; + call.finalize_hook(); + Op::Call(call).into() } } diff --git a/mir/src/ir/nodes/ops/enf.rs b/mir/src/ir/nodes/ops/enf.rs index 0206a4023..e5ecc0c3e 100644 --- a/mir/src/ir/nodes/ops/enf.rs +++ b/mir/src/ir/nodes/ops/enf.rs @@ -1,6 +1,7 @@ +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to enforce that a given MIR op, `expr` equals zero #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -12,11 +13,38 @@ pub struct Enf { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _ty: Option, +} + +impl ScalarTypeMut for Enf { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._ty.scalar_ty_mut() + } +} + +impl TypeMut for Enf { + fn ty_mut(&mut self) -> &mut Option { + self._ty.ty_mut() + } +} + +impl Typing for Enf { + fn ty(&self) -> Option { + self._ty.ty() + } +} + +impl BuilderHook for Enf { + fn finalize_hook(&mut self) { + self._ty = self.expr.borrow().ty(); + } } impl Enf { pub fn create(expr: Link, span: SourceSpan) -> Link { - Op::Enf(Self { expr, span, ..Default::default() }).into() + let mut enf = Self { expr, span, ..Default::default() }; + enf.finalize_hook(); + Op::Enf(enf).into() } } diff --git a/mir/src/ir/nodes/ops/exp.rs b/mir/src/ir/nodes/ops/exp.rs index c888d1cf6..3bda91cf9 100644 --- a/mir/src/ir/nodes/ops/exp.rs +++ b/mir/src/ir/nodes/ops/exp.rs @@ -1,6 +1,7 @@ +use air_types::{BinType, ScalarTypeMut, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent the exponentiation of a MIR op, `lhs` by another, `rhs` /// @@ -15,11 +16,46 @@ pub struct Exp { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _bin_ty: BinType, +} + +impl ScalarTypeMut for Exp { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._bin_ty.scalar_ty_mut() + } +} + +impl TypeMut for Exp { + fn ty_mut(&mut self) -> &mut Option { + self._bin_ty.ty_mut() + } +} + +impl Typing for Exp { + fn ty(&self) -> Option { + self._bin_ty.ty() + } +} + +impl BuilderHook for Exp { + fn finalize_hook(&mut self) { + self._bin_ty = BinType::Exp(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); + let res = self._bin_ty.infer_bin_ty_add().unwrap(); + *self._bin_ty.result_mut() = res; + } } impl Exp { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Exp(Self { lhs, rhs, span, ..Default::default() }).into() + let mut exp = Self { + lhs, + rhs, + span, + _bin_ty: BinType::default(), + ..Default::default() + }; + exp.finalize_hook(); + Op::Exp(exp).into() } } diff --git a/mir/src/ir/nodes/ops/fold.rs b/mir/src/ir/nodes/ops/fold.rs index 0ca0a7c9d..346639df6 100644 --- a/mir/src/ir/nodes/ops/fold.rs +++ b/mir/src/ir/nodes/ops/fold.rs @@ -1,6 +1,7 @@ +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent folding a given Vector operator according to a given operator and /// initial value @@ -21,8 +22,29 @@ pub struct Fold { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub ty: Option, } +impl ScalarTypeMut for Fold { + fn scalar_ty_mut(&mut self) -> &mut Option { + self.ty.scalar_ty_mut() + } +} + +impl TypeMut for Fold { + fn ty_mut(&mut self) -> &mut Option { + self.ty.ty_mut() + } +} + +impl Typing for Fold { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for Fold {} + #[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] pub enum FoldOperator { Add, diff --git a/mir/src/ir/nodes/ops/for_op.rs b/mir/src/ir/nodes/ops/for_op.rs index 9ab9a200d..9ea8bdc7c 100644 --- a/mir/src/ir/nodes/ops/for_op.rs +++ b/mir/src/ir/nodes/ops/for_op.rs @@ -1,6 +1,7 @@ +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent list comprehensions. /// @@ -20,8 +21,29 @@ pub struct For { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub ty: Option, } +impl ScalarTypeMut for For { + fn scalar_ty_mut(&mut self) -> &mut Option { + self.ty.scalar_ty_mut() + } +} + +impl TypeMut for For { + fn ty_mut(&mut self) -> &mut Option { + self.ty.ty_mut() + } +} + +impl Typing for For { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for For {} + impl For { pub fn create( iterators: Link>>, diff --git a/mir/src/ir/nodes/ops/if_op.rs b/mir/src/ir/nodes/ops/if_op.rs index 023f65634..64657f8ad 100644 --- a/mir/src/ir/nodes/ops/if_op.rs +++ b/mir/src/ir/nodes/ops/if_op.rs @@ -1,6 +1,7 @@ +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent conditional constraints /// @@ -18,8 +19,29 @@ pub struct If { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub ty: Option, } +impl ScalarTypeMut for If { + fn scalar_ty_mut(&mut self) -> &mut Option { + self.ty.scalar_ty_mut() + } +} + +impl TypeMut for If { + fn ty_mut(&mut self) -> &mut Option { + self.ty.ty_mut() + } +} + +impl Typing for If { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for If {} + #[derive(Default, Clone, PartialEq, Eq, Debug, Hash)] pub struct MatchArm { pub condition: Link, diff --git a/mir/src/ir/nodes/ops/matrix.rs b/mir/src/ir/nodes/ops/matrix.rs index 277783ca4..d87cebacf 100644 --- a/mir/src/ir/nodes/ops/matrix.rs +++ b/mir/src/ir/nodes/ops/matrix.rs @@ -1,6 +1,7 @@ +use air_types::{Kind, ScalarTypeMut, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent a matrix of MIR ops of a given size #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -14,18 +15,44 @@ pub struct Matrix { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _kind: Option, +} + +impl ScalarTypeMut for Matrix { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._kind.as_mut().unwrap().scalar_ty_mut() + } +} + +impl TypeMut for Matrix { + fn ty_mut(&mut self) -> &mut Option { + self._kind.as_mut().unwrap().ty_mut() + } +} + +impl Typing for Matrix { + fn ty(&self) -> Option { + self._kind.ty() + } +} + +impl BuilderHook for Matrix { + fn finalize_hook(&mut self) { + self._kind = self.elements.borrow().kind(); + } } impl Matrix { pub fn create(elements: Vec>, span: SourceSpan) -> Link { let size = elements.len(); - Op::Matrix(Self { + let mut mat = Self { size, elements: Link::new(elements), span, ..Default::default() - }) - .into() + }; + mat.finalize_hook(); + Op::Matrix(mat).into() } } diff --git a/mir/src/ir/nodes/ops/mod.rs b/mir/src/ir/nodes/ops/mod.rs index 9e0b48040..5ead181af 100644 --- a/mir/src/ir/nodes/ops/mod.rs +++ b/mir/src/ir/nodes/ops/mod.rs @@ -17,7 +17,6 @@ mod vector; pub use accessor::Accessor; pub use add::Add; -pub use air_types::*; pub use boundary::Boundary; pub use bus_op::{BusOp, BusOpKind}; pub use call::Call; diff --git a/mir/src/ir/nodes/ops/mul.rs b/mir/src/ir/nodes/ops/mul.rs index 247123755..b4a678413 100644 --- a/mir/src/ir/nodes/ops/mul.rs +++ b/mir/src/ir/nodes/ops/mul.rs @@ -1,6 +1,7 @@ +use air_types::{BinType, ScalarTypeMut, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent the multiplication of two MIR ops, `lhs` and `rhs` #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -13,11 +14,46 @@ pub struct Mul { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _bin_ty: BinType, +} + +impl ScalarTypeMut for Mul { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._bin_ty.scalar_ty_mut() + } +} + +impl TypeMut for Mul { + fn ty_mut(&mut self) -> &mut Option { + self._bin_ty.ty_mut() + } +} + +impl Typing for Mul { + fn ty(&self) -> Option { + self._bin_ty.ty() + } +} + +impl BuilderHook for Mul { + fn finalize_hook(&mut self) { + self._bin_ty = BinType::Mul(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); + let res = self._bin_ty.infer_bin_ty_add().unwrap(); + *self._bin_ty.result_mut() = res; + } } impl Mul { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Mul(Self { lhs, rhs, span, ..Default::default() }).into() + let mut mul = Self { + lhs, + rhs, + span, + _bin_ty: BinType::default(), + ..Default::default() + }; + mul.finalize_hook(); + Op::Mul(mul).into() } } diff --git a/mir/src/ir/nodes/ops/parameter.rs b/mir/src/ir/nodes/ops/parameter.rs index 6bc3eedae..01b25ae85 100644 --- a/mir/src/ir/nodes/ops/parameter.rs +++ b/mir/src/ir/nodes/ops/parameter.rs @@ -1,9 +1,9 @@ use std::hash::{Hash, Hasher}; -use air_types::*; +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Singleton}; /// A MIR operation to represent a `Parameter` in a function or evaluator. /// Also used in If and For loops to represent declared parameters. @@ -15,13 +15,33 @@ pub struct Parameter { pub ref_node: BackLink, /// The position of the `Parameter` in the referred node's `Parameter` list pub position: usize, - /// The type of the `Parameter` - pub ty: Option, pub _node: Singleton, #[span] pub span: SourceSpan, + /// The type of the `Parameter` + pub ty: Option, +} + +impl ScalarTypeMut for Parameter { + fn scalar_ty_mut(&mut self) -> &mut Option { + self.ty.scalar_ty_mut() + } } +impl TypeMut for Parameter { + fn ty_mut(&mut self) -> &mut Option { + self.ty.ty_mut() + } +} + +impl Typing for Parameter { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for Parameter {} + impl Parameter { pub fn create(position: usize, ty: Type, span: SourceSpan) -> Link { Op::Parameter(Self { diff --git a/mir/src/ir/nodes/ops/sub.rs b/mir/src/ir/nodes/ops/sub.rs index 94e2c00ad..5647a62e1 100644 --- a/mir/src/ir/nodes/ops/sub.rs +++ b/mir/src/ir/nodes/ops/sub.rs @@ -1,6 +1,7 @@ +use air_types::{BinType, ScalarTypeMut, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent the subtraction of two MIR ops, `lhs` and `rhs` #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -13,11 +14,46 @@ pub struct Sub { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _bin_ty: BinType, +} + +impl ScalarTypeMut for Sub { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._bin_ty.scalar_ty_mut() + } +} + +impl TypeMut for Sub { + fn ty_mut(&mut self) -> &mut Option { + self._bin_ty.ty_mut() + } +} + +impl Typing for Sub { + fn ty(&self) -> Option { + self._bin_ty.ty() + } +} + +impl BuilderHook for Sub { + fn finalize_hook(&mut self) { + self._bin_ty = BinType::Sub(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); + let res = self._bin_ty.infer_bin_ty_add().unwrap(); + *self._bin_ty.result_mut() = res; + } } impl Sub { pub fn create(lhs: Link, rhs: Link, span: SourceSpan) -> Link { - Op::Sub(Self { lhs, rhs, span, ..Default::default() }).into() + let mut sub = Self { + lhs, + rhs, + span, + _bin_ty: BinType::default(), + ..Default::default() + }; + sub.finalize_hook(); + Op::Sub(sub).into() } } diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index fea10e8a8..83be947e4 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -1,7 +1,8 @@ use air_parser::ast::{BusType, Identifier, QualifiedIdentifier, TraceColumnIndex, TraceSegmentId}; +use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Bus, Child, Link, Node, Op, Owner, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Bus, Child, Link, Node, Op, Owner, Singleton}; /// A MIR operation to represent a known value, [Value]. /// @@ -13,8 +14,29 @@ pub struct Value { #[span] pub value: SpannedMirValue, pub _node: Singleton, + pub ty: Option, } +impl ScalarTypeMut for Value { + fn scalar_ty_mut(&mut self) -> &mut Option { + self.ty.scalar_ty_mut() + } +} + +impl TypeMut for Value { + fn ty_mut(&mut self) -> &mut Option { + self.ty.ty_mut() + } +} + +impl Typing for Value { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl BuilderHook for Value {} + impl Value { pub fn create(value: SpannedMirValue) -> Link { Op::Value(Self { value, ..Default::default() }).into() diff --git a/mir/src/ir/nodes/ops/vector.rs b/mir/src/ir/nodes/ops/vector.rs index fc53c3d67..167b3b3e3 100644 --- a/mir/src/ir/nodes/ops/vector.rs +++ b/mir/src/ir/nodes/ops/vector.rs @@ -1,6 +1,7 @@ +use air_types::{Kind, ScalarTypeMut, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{BackLink, Builder, Child, Link, Node, Op, Owner, Parent, Singleton}; +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; /// A MIR operation to represent a vector of MIR ops of a given size #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -13,18 +14,44 @@ pub struct Vector { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub _kind: Option, +} + +impl ScalarTypeMut for Vector { + fn scalar_ty_mut(&mut self) -> &mut Option { + self._kind.as_mut().unwrap().scalar_ty_mut() + } +} + +impl TypeMut for Vector { + fn ty_mut(&mut self) -> &mut Option { + self._kind.as_mut().unwrap().ty_mut() + } +} + +impl Typing for Vector { + fn ty(&self) -> Option { + self._kind.ty() + } +} + +impl BuilderHook for Vector { + fn finalize_hook(&mut self) { + self._kind = self.elements.borrow().kind(); + } } impl Vector { pub fn create(elements: Vec>, span: SourceSpan) -> Link { let size = elements.len(); - Op::Vector(Self { + let mut vec = Self { size, elements: Link::new(elements), span, ..Default::default() - }) - .into() + }; + vec.finalize_hook(); + Op::Vector(vec).into() } } diff --git a/mir/src/ir/nodes/root.rs b/mir/src/ir/nodes/root.rs index 0c8a4c9d5..29f6ef04f 100644 --- a/mir/src/ir/nodes/root.rs +++ b/mir/src/ir/nodes/root.rs @@ -3,11 +3,12 @@ use std::{ ops::{Deref, DerefMut}, }; -use miden_diagnostics::{SourceSpan, Spanned}; +use air_types::Typing; +use miden_diagnostics::Spanned; use crate::ir::{ - BackLink, Evaluator, Function, Link, Node, Op, Owner, Parent, Singleton, get_inner, - get_inner_mut, + BackLink, BuilderHook, Evaluator, Function, Link, Node, None, Op, Owner, Parent, Singleton, + get_inner, get_inner_mut, }; /// The root nodes of the MIR Graph @@ -17,12 +18,22 @@ use crate::ir::{ pub enum Root { Function(Function), Evaluator(Evaluator), - None(SourceSpan), + None(None), +} + +impl BuilderHook for Root { + fn finalize_hook(&mut self) { + match self { + Root::Function(f) => f.finalize_hook(), + Root::Evaluator(e) => e.finalize_hook(), + Root::None(_) => {}, + } + } } impl Default for Root { fn default() -> Self { - Root::None(SourceSpan::default()) + Root::None(None::default()) } } @@ -37,6 +48,16 @@ impl Parent for Root { } } +impl Typing for Root { + fn ty(&self) -> Option { + match self { + Root::Function(f) => f.ty(), + Root::Evaluator(e) => e.ty(), + Root::None(n) => n.ty(), + } + } +} + impl Link { pub fn debug(&self) -> String { match self.borrow().deref() { @@ -70,7 +91,7 @@ impl Link { e._node = Singleton::from(node.clone()); node }, - Root::None(span) => Node::None(*span).into(), + Root::None(none) => Node::None(none.clone()).into(), } } /// Get the current [Root]'s [Owner] variant @@ -90,7 +111,7 @@ impl Link { e._owner = Singleton::from(owner.clone()); owner }, - Root::None(span) => Owner::None(*span).into(), + Root::None(none) => Owner::None(none.clone()).into(), } } /// Try getting the current [Root]'s inner [Function]. diff --git a/mir/src/ir/nodes/roots/evaluator.rs b/mir/src/ir/nodes/roots/evaluator.rs index 35e0fd6e7..67d1363e7 100644 --- a/mir/src/ir/nodes/roots/evaluator.rs +++ b/mir/src/ir/nodes/roots/evaluator.rs @@ -1,6 +1,7 @@ +use air_types::{FunctionType, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; +use crate::ir::{Builder, BuilderHook, Link, Node, Op, Owner, Parent, Root, Singleton}; /// A MIR Root to represent a Evaluator definition #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -15,18 +16,29 @@ pub struct Evaluator { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub func_ty: FunctionType, } +impl Typing for Evaluator { + fn ty(&self) -> Option { + self.func_ty.result() + } +} + +impl BuilderHook for Evaluator {} + impl Evaluator { pub fn create( parameters: Vec>>, body: Vec>, span: SourceSpan, + func_ty: FunctionType, ) -> Link { Root::Evaluator(Self { parameters, body: Link::new(body), span, + func_ty, ..Default::default() }) .into() diff --git a/mir/src/ir/nodes/roots/function.rs b/mir/src/ir/nodes/roots/function.rs index 12051f466..251075c23 100644 --- a/mir/src/ir/nodes/roots/function.rs +++ b/mir/src/ir/nodes/roots/function.rs @@ -1,6 +1,7 @@ +use air_types::{FunctionType, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; -use crate::ir::{Builder, Link, Node, Op, Owner, Parent, Root, Singleton}; +use crate::ir::{Builder, BuilderHook, Link, Node, Op, Owner, Parent, Root, Singleton}; /// A MIR Root to represent a Function definition #[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] @@ -16,20 +17,31 @@ pub struct Function { pub _owner: Singleton, #[span] pub span: SourceSpan, + pub func_ty: FunctionType, } +impl Typing for Function { + fn ty(&self) -> Option { + self.func_ty.result() + } +} + +impl BuilderHook for Function {} + impl Function { pub fn create( parameters: Vec>, return_type: Link, body: Vec>, span: SourceSpan, + func_ty: FunctionType, ) -> Link { Root::Function(Self { parameters, return_type, body: Link::new(body), span, + func_ty, ..Default::default() }) .into() diff --git a/mir/src/ir/owner.rs b/mir/src/ir/owner.rs index d1a2c513b..f2db53dd5 100644 --- a/mir/src/ir/owner.rs +++ b/mir/src/ir/owner.rs @@ -1,8 +1,8 @@ use std::ops::Deref; -use miden_diagnostics::{SourceSpan, Spanned}; +use miden_diagnostics::Spanned; -use crate::ir::{BackLink, Child, Link, Node, Op, Parent, Root}; +use crate::ir::{BackLink, Child, Link, Node, None, Op, Parent, Root}; /// The nodes that can own [Op] nodes /// The [Owner] enum does not own it's inner struct to avoid reference cycles, @@ -30,7 +30,7 @@ pub enum Owner { Enf(BackLink), For(BackLink), If(BackLink), - None(SourceSpan), + None(None), } impl Parent for Owner { @@ -199,17 +199,17 @@ impl Link { Op::BusOp(_) => Owner::BusOp(BackLink::from(op_inner_val)), Op::Parameter(_) => unreachable!(), Op::Value(_) => unreachable!(), - Op::None(span) => Owner::None(*span), + Op::None(none) => Owner::None(none.clone()), }; } else if let Some(root_inner_val) = self.as_root() { to_update = match root_inner_val.clone().borrow().deref() { Root::Function(_) => Owner::Function(BackLink::from(root_inner_val)), Root::Evaluator(_) => Owner::Evaluator(BackLink::from(root_inner_val)), - Root::None(span) => Owner::None(*span), + Root::None(none) => Owner::None(none.clone()), }; } else { // If the [Owner] is stale, we set it to None - to_update = Owner::None(self.span()); + to_update = Owner::None(None { span: self.span(), ty: None }); } *self.borrow_mut() = to_update; diff --git a/mir/src/passes/mod.rs b/mir/src/passes/mod.rs index 8bf2123d4..864b5f0ea 100644 --- a/mir/src/passes/mod.rs +++ b/mir/src/passes/mod.rs @@ -200,7 +200,7 @@ pub fn duplicate_node( new_param }, Op::Value(value) => Value::create(value.value.clone()), - Op::None(span) => Op::None(*span).into(), + Op::None(none) => Op::None(none.clone()).into(), } } diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 7670a0984..0f6eaaf30 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -16,7 +16,7 @@ use crate::{ Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, Root, - SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Type, Value, Vector, + SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Type, Value, Vector, none, }, passes::duplicate_node, }; @@ -147,6 +147,7 @@ impl<'a> MirBuilder<'a> { ast_eval: &'a ast::EvaluatorFunction, ) -> Result, CompileError> { let mut all_params_flatten = Vec::new(); + let mut all_params_ty_flatten = Vec::new(); self.root_name = Some(ident); let mut ev = Evaluator::builder().span(ast_eval.span); @@ -163,12 +164,13 @@ impl<'a> MirBuilder<'a> { for param in params { all_params_flatten_for_trace_segment.push(param.clone()); all_params_flatten.push(param.clone()); + all_params_ty_flatten.push(binding.ty()); } } ev = ev.parameters(all_params_flatten_for_trace_segment.clone()); } - let ev = ev.build(); + let ev = ev.func_ty(FunctionType::Evaluator(all_params_ty_flatten)).build(); set_all_ref_nodes(all_params_flatten.clone(), ev.as_owner()); @@ -233,6 +235,7 @@ impl<'a> MirBuilder<'a> { ast_func: &'a ast::Function, ) -> Result, CompileError> { let mut params = Vec::new(); + let mut params_ty = Vec::new(); self.root_name = Some(ident); let mut func = Function::builder().span(ast_func.span()); @@ -241,13 +244,17 @@ impl<'a> MirBuilder<'a> { let name = Some(param_ident); let param = self.translate_params_fn(param_ident.span(), name, ty, &mut i)?; params.push(param.clone()); + params_ty.push(Some(*ty)); func = func.parameters(param.clone()); } i += 1; let ret = Parameter::create(i, ast_func.return_type, ast_func.span()); params.push(ret.clone()); - let func = func.return_type(ret).build(); + let func = func + .return_type(ret) + .func_ty(FunctionType::Function(params_ty, Some(ast_func.return_type))) + .build(); set_all_ref_nodes(params.clone(), func.as_owner()); self.mir.constraint_graph_mut().insert_function(*ident, func.clone())?; @@ -486,8 +493,8 @@ impl<'a> MirBuilder<'a> { let for_node = For::create( iterator_nodes.into(), - Op::None(list_comp.span()).into(), - Op::None(list_comp.span()).into(), + Op::None(none::None { span: list_comp.span(), ty: None }).into(), + Op::None(none::None { span: list_comp.span(), ty: None }).into(), list_comp.span(), ); set_all_ref_nodes(params, for_node.as_owner().unwrap()); @@ -701,6 +708,7 @@ impl<'a> MirBuilder<'a> { pc.period(), )), }) + .ty(pc.ty()) .build(); Ok(node) } else if let Some(bus) = self.mir.constraint_graph().get_bus_link(&qual_ident) { @@ -709,6 +717,7 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::BusAccess(BusAccess::new(bus.clone(), access.offset)), }) + .ty(bus.borrow().ty()) .build(); Ok(node) } else { @@ -790,7 +799,7 @@ impl<'a> MirBuilder<'a> { })?; }, } - return Ok(Op::None(bin_op.span()).into()); + return Ok(Op::None(none::None { span: bin_op.span(), ty: None }).into()); } } } @@ -822,33 +831,34 @@ impl<'a> MirBuilder<'a> { fn translate_call(&mut self, call: &'a ast::Call) -> Result, CompileError> { // First, resolve the callee, panic if it's not resolved let resolved_callee = call.callee.resolved().unwrap(); - if call.is_builtin() { // If it's a fold operator (Sum / Prod), handle it match call.callee.as_ref().name() { symbols::Sum => { assert_eq!(call.args.len(), 1); + let acc = ast::ConstantExpr::Scalar(0); let iterator_node = self.translate_expr(call.args.first().unwrap())?; - let accumulator_node = - self.translate_const(&ast::ConstantExpr::Scalar(0), call.span())?; + let accumulator_node = self.translate_const(&acc, call.span())?; let node = Fold::builder() .span(call.span()) .iterator(iterator_node) .operator(FoldOperator::Add) .initial_value(accumulator_node) + .ty(acc.ty()) .build(); Ok(node) }, symbols::Prod => { assert_eq!(call.args.len(), 1); + let acc = ast::ConstantExpr::Scalar(1); let iterator_node = self.translate_expr(call.args.first().unwrap())?; - let accumulator_node = - self.translate_const(&ast::ConstantExpr::Scalar(1), call.span())?; + let accumulator_node = self.translate_const(&acc, call.span())?; let node = Fold::builder() .span(call.span()) .iterator(iterator_node) .operator(FoldOperator::Mul) .initial_value(accumulator_node) + .ty(acc.ty()) .build(); Ok(node) }, @@ -912,6 +922,7 @@ impl<'a> MirBuilder<'a> { } // safe to unwrap because we know it is a Function due to get_function let callee_ref = callee.as_function().unwrap(); + if callee_ref.parameters.len() != arg_nodes.len() { self.diagnostics .diagnostic(Severity::Error) @@ -934,6 +945,27 @@ impl<'a> MirBuilder<'a> { .emit(); return Err(CompileError::Failed); } + + let arg_kinds = arg_nodes.iter().map(|arg| arg.kind().unwrap()).collect::>(); + let arg_kinds_refs = arg_kinds.iter().collect::>(); + if callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("arguments typing mismatch") + .with_primary_label( + call.span(), + format!("called function with arguments {:?}", arg_kinds_refs), + ) + .with_secondary_label( + call.callee.span(), + format!( + "this functions has parameters: {:?}", + callee_ref.func_ty.params() + ), + ) + .emit(); + return Err(CompileError::Failed); + } } else if let Some(callee) = self.mir.constraint_graph().get_evaluator_root(&resolved_callee) { @@ -971,6 +1003,26 @@ impl<'a> MirBuilder<'a> { .emit(); return Err(CompileError::Failed); } + let arg_kinds = arg_nodes.iter().map(|arg| arg.kind().unwrap()).collect::>(); + let arg_kinds_refs = arg_kinds.iter().collect::>(); + if callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { + self.diagnostics + .diagnostic(Severity::Error) + .with_message("arguments typing mismatch") + .with_primary_label( + call.span(), + format!("called evaluator with arguments {:?}", arg_kinds_refs), + ) + .with_secondary_label( + call.callee.span(), + format!( + "this evaluator has parameters: {:?}", + callee_ref.func_ty.params() + ), + ) + .emit(); + return Err(CompileError::Failed); + } } else { panic!("Unknown function or evaluator: {:?}", resolved_callee); } @@ -1033,7 +1085,9 @@ impl<'a> MirBuilder<'a> { scalar_expr: &'a ast::ScalarExpr, ) -> Result, CompileError> { match scalar_expr { - ast::ScalarExpr::Const(c) => self.translate_scalar_const(c.item, c.span()), + ast::ScalarExpr::Const(c) => { + self.translate_scalar_const(c.item, c.span(), scalar_expr.scalar_ty()) + }, ast::ScalarExpr::SymbolAccess(s) => self.translate_symbol_access(s), ast::ScalarExpr::BoundedSymbolAccess(s) => self.translate_bounded_symbol_access(s), ast::ScalarExpr::Binary(b) => self.translate_binary_op(b), @@ -1055,12 +1109,13 @@ impl<'a> MirBuilder<'a> { &mut self, c: u64, span: SourceSpan, + sty: Option, ) -> Result, CompileError> { let value = SpannedMirValue { value: MirValue::Constant(ConstantValue::Felt(c)), span, }; - let node = Value::builder().value(value).build(); + let node = Value::builder().value(value).ty(ty!(sty)).build(); Ok(node) } @@ -1136,7 +1191,7 @@ impl<'a> MirBuilder<'a> { bus_op = bus_op.args(arg_node); } // Latch is unknown at this point, will be set later in translate_bus_enforce - let bus_op = bus_op.latch(1.into()).build(); + let bus_op = bus_op.latch(1.into()).ty(ast_bus_op.ty()).build(); Ok(bus_op) } @@ -1146,9 +1201,13 @@ impl<'a> MirBuilder<'a> { span: SourceSpan, ) -> Result, CompileError> { match c { - ast::ConstantExpr::Scalar(s) => self.translate_scalar_const(*s, span), - ast::ConstantExpr::Vector(v) => self.translate_vector_const(v.clone(), span), - ast::ConstantExpr::Matrix(m) => self.translate_matrix_const(m.clone(), span), + ast::ConstantExpr::Scalar(s) => self.translate_scalar_const(*s, span, c.scalar_ty()), + ast::ConstantExpr::Vector(v) => { + self.translate_vector_const(v.clone(), span, c.scalar_ty()) + }, + ast::ConstantExpr::Matrix(m) => { + self.translate_matrix_const(m.clone(), span, c.scalar_ty()) + }, } } @@ -1156,10 +1215,11 @@ impl<'a> MirBuilder<'a> { &mut self, v: Vec, span: SourceSpan, + sty: Option, ) -> Result, CompileError> { let mut node = Vector::builder().size(v.len()).span(span); for value in v.iter() { - let value_node = self.translate_scalar_const(*value, span)?; + let value_node = self.translate_scalar_const(*value, span, sty)?; node = node.elements(value_node); } Ok(node.build()) @@ -1169,10 +1229,11 @@ impl<'a> MirBuilder<'a> { &mut self, m: Vec>, span: SourceSpan, + sty: Option, ) -> Result, CompileError> { let mut node = Matrix::builder().size(m.len()).span(span); for row in m.iter() { - let row_node = self.translate_vector_const(row.clone(), span)?; + let row_node = self.translate_vector_const(row.clone(), span, sty)?; node = node.elements(row_node); } let node = node.build(); @@ -1194,6 +1255,7 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::TraceAccess(trace_access), }) + .ty(trace_access.ty()) .build()); } @@ -1203,6 +1265,7 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::TraceAccessBinding(tab), }) + .ty(tab.ty()) .build()); } @@ -1240,6 +1303,7 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::TraceAccess(trace_access), }) + .ty(trace_access.ty()) .build()); } @@ -1250,6 +1314,7 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::TraceAccessBinding(tab), }) + .ty(tab.ty()) .build()); } @@ -1260,6 +1325,7 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::PublicInput(public_input_access), }) + .ty(public_input_access.ty()) .build()); }, (None, Some(public_input_table_access)) => { @@ -1268,6 +1334,7 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::PublicInputTable(public_input_table_access), }) + .ty(public_input_table_access.ty()) .build()); }, _ => {}, diff --git a/mir/src/passes/unrolling/unrolling_first_pass.rs b/mir/src/passes/unrolling/unrolling_first_pass.rs index 970f535c5..ffe99e334 100644 --- a/mir/src/passes/unrolling/unrolling_first_pass.rs +++ b/mir/src/passes/unrolling/unrolling_first_pass.rs @@ -1,15 +1,15 @@ use std::{collections::HashMap, ops::Deref}; use air_parser::ast::AccessType; +use air_types::{Type, ty}; use miden_diagnostics::{DiagnosticsHandler, SourceSpan, Spanned}; -use air_types::{ty, Type}; use crate::{ CompileError, ir::{ Accessor, Add, BackLink, Boundary, ConstantValue, Enf, Exp, FoldOperator, Graph, Link, - Matrix, MirValue, Mul, Node, Op, Owner, Parameter, Parent, RandomInputs, - SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Value, Vector, + Matrix, MirValue, Mul, Node, Op, Owner, Parameter, Parent, RandomInputs, SpannedMirValue, + Sub, TraceAccess, TraceAccessBinding, Value, Vector, }, passes::{ Visitor, diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index ba78ce5cc..c3ab4120e 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -723,11 +723,11 @@ impl VisitMut for SemanticAnalysis<'_> { // // * Must be trace bindings or aliases of same // * Must match the type signature of the callee - if let Ok(ty) = callee_binding_ty { - if let BindingType::Evaluator(params) = ty.item { - for (arg, param) in expr.args.iter().zip(params.iter()) { - self.validate_evaluator_argument(expr.span(), arg, param)?; - } + if let Ok(ty) = callee_binding_ty + && let BindingType::Evaluator(params) = ty.item + { + for (arg, param) in expr.args.iter().zip(params.iter()) { + self.validate_evaluator_argument(expr.span(), arg, param)?; } } @@ -1577,17 +1577,17 @@ impl SemanticAnalysis<'_> { // // If no type is known, a diagnostic is already emitted, so proceed // as if it is valid - if let Some(ty) = access.column.ty.as_ref() { - if !ty.is_scalar() { - // Invalid constraint, only scalar values are allowed - self.type_mismatch( - Some(ty), - access.span(), - &ty!(_).unwrap(), - found.span(), - constraint_span, - )?; - } + if let Some(ty) = access.column.ty.as_ref() + && !ty.is_scalar() + { + // Invalid constraint, only scalar values are allowed + self.type_mismatch( + Some(ty), + access.span(), + &ty!(_).unwrap(), + found.span(), + constraint_span, + )?; } // Verify that the right-hand expression evaluates to a scalar diff --git a/types/src/types.rs b/types/src/types.rs index 88717874e..0fa0d8a8a 100644 --- a/types/src/types.rs +++ b/types/src/types.rs @@ -184,6 +184,12 @@ pub enum FunctionType { Function(Vec>, Option), } +impl Default for FunctionType { + fn default() -> Self { + Self::Evaluator(vec![]) + } +} + impl FunctionType { pub fn params(&self) -> &[Option] { match self { @@ -200,7 +206,7 @@ impl FunctionType { } pub fn check_args_kinds(&self, args: &[&Kind]) -> bool { - eprintln!("Checking function type {} against params {:?}", self, args); + eprintln!("Checking function type {self} against params {args:?}"); let params = self.params(); if params.len() != args.len() { return false; @@ -236,7 +242,7 @@ impl core::fmt::Display for FunctionType { )?; f.write_str(") -> ")?; if let Some(ret_type) = ret { - write!(f, "{}", ret_type) + write!(f, "{ret_type}") } else { f.write_str("?") } @@ -267,6 +273,12 @@ pub enum BinType { Exp(Option, Option, Option), } +impl Default for BinType { + fn default() -> Self { + Self::Eq(None, None, None) + } +} + impl BinType { pub fn lhs(&self) -> Option { match self { @@ -693,6 +705,12 @@ pub enum Kind { Callable(FunctionType), } +impl Default for Kind { + fn default() -> Self { + Self::Value(None) + } +} + impl core::fmt::Display for Kind { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { From 46f80654f9d21e118048bdb8b11f9bb86a9aff7c Mon Sep 17 00:00:00 2001 From: Leo-Besancon Date: Mon, 15 Sep 2025 09:25:06 +0200 Subject: [PATCH 32/42] refactor: rename none::None to stale::Stale --- mir/src/ir/node.rs | 6 +++--- mir/src/ir/nodes/mod.rs | 4 ++-- mir/src/ir/nodes/op.rs | 4 ++-- mir/src/ir/nodes/root.rs | 6 +++--- mir/src/ir/nodes/{none.rs => stale.rs} | 12 ++++++------ mir/src/ir/owner.rs | 6 +++--- mir/src/passes/translate.rs | 15 +++++---------- 7 files changed, 24 insertions(+), 29 deletions(-) rename mir/src/ir/nodes/{none.rs => stale.rs} (75%) diff --git a/mir/src/ir/node.rs b/mir/src/ir/node.rs index 30fe5925d..90a800326 100644 --- a/mir/src/ir/node.rs +++ b/mir/src/ir/node.rs @@ -2,7 +2,7 @@ use std::ops::Deref; use miden_diagnostics::Spanned; -use crate::ir::{BackLink, Child, Link, None, Op, Owner, Parent, Root}; +use crate::ir::{BackLink, Child, Link, Stale, Op, Owner, Parent, Root}; /// All the nodes that can be in the MIR Graph /// Combines all [Root] and [Op] variants @@ -33,7 +33,7 @@ pub enum Node { BusOp(BackLink), Parameter(BackLink), Value(BackLink), - None(None), + None(Stale), } impl Default for Node { @@ -230,7 +230,7 @@ impl Link { }; } else { // If the [Node] is stale, we set it to None - to_update = Node::None(None { span: self.span(), ty: None }); + to_update = Node::None(Stale { span: self.span(), ty: None }); } *self.borrow_mut() = to_update; diff --git a/mir/src/ir/nodes/mod.rs b/mir/src/ir/nodes/mod.rs index ae26a94d3..2ca8f45bb 100644 --- a/mir/src/ir/nodes/mod.rs +++ b/mir/src/ir/nodes/mod.rs @@ -1,4 +1,4 @@ -pub mod none; +pub mod stale; mod op; mod ops; mod root; @@ -6,7 +6,7 @@ mod roots; use std::cell::{Ref, RefMut}; -pub use none::*; +pub use stale::*; pub use op::Op; pub use ops::*; pub use root::Root; diff --git a/mir/src/ir/nodes/op.rs b/mir/src/ir/nodes/op.rs index 6b269f1dc..d2751987e 100644 --- a/mir/src/ir/nodes/op.rs +++ b/mir/src/ir/nodes/op.rs @@ -8,7 +8,7 @@ use miden_diagnostics::Spanned; use crate::ir::{ Accessor, Add, BackLink, Boundary, BuilderHook, BusOp, Call, Child, ConstantValue, Enf, Exp, - Fold, For, If, Link, Matrix, MirValue, Mul, Node, None, Owner, Parameter, Parent, Singleton, + Fold, For, If, Link, Matrix, MirValue, Mul, Node, Stale, Owner, Parameter, Parent, Singleton, SpannedMirValue, Sub, Value, Vector, get_inner, get_inner_mut, }; @@ -34,7 +34,7 @@ pub enum Op { BusOp(BusOp), Parameter(Parameter), Value(Value), - None(None), + None(Stale), } impl BuilderHook for Op { diff --git a/mir/src/ir/nodes/root.rs b/mir/src/ir/nodes/root.rs index 29f6ef04f..53dc61b85 100644 --- a/mir/src/ir/nodes/root.rs +++ b/mir/src/ir/nodes/root.rs @@ -7,7 +7,7 @@ use air_types::Typing; use miden_diagnostics::Spanned; use crate::ir::{ - BackLink, BuilderHook, Evaluator, Function, Link, Node, None, Op, Owner, Parent, Singleton, + BackLink, BuilderHook, Evaluator, Function, Link, Node, Stale, Op, Owner, Parent, Singleton, get_inner, get_inner_mut, }; @@ -18,7 +18,7 @@ use crate::ir::{ pub enum Root { Function(Function), Evaluator(Evaluator), - None(None), + None(Stale), } impl BuilderHook for Root { @@ -33,7 +33,7 @@ impl BuilderHook for Root { impl Default for Root { fn default() -> Self { - Root::None(None::default()) + Root::None(Stale::default()) } } diff --git a/mir/src/ir/nodes/none.rs b/mir/src/ir/nodes/stale.rs similarity index 75% rename from mir/src/ir/nodes/none.rs rename to mir/src/ir/nodes/stale.rs index 640fe9dfe..501989597 100644 --- a/mir/src/ir/nodes/none.rs +++ b/mir/src/ir/nodes/stale.rs @@ -2,32 +2,32 @@ use air_types::{ScalarTypeMut, Type, TypeMut, Typing}; use miden_diagnostics::{SourceSpan, Spanned}; #[derive(Clone, PartialEq, Eq, Debug, Hash, Spanned)] -pub struct None { +pub struct Stale { #[span] pub span: SourceSpan, pub ty: Option, } -impl ScalarTypeMut for None { +impl ScalarTypeMut for Stale { fn scalar_ty_mut(&mut self) -> &mut Option { self.ty.scalar_ty_mut() } } -impl TypeMut for None { +impl TypeMut for Stale { fn ty_mut(&mut self) -> &mut Option { self.ty.ty_mut() } } -impl Typing for None { +impl Typing for Stale { fn ty(&self) -> Option { self.ty.ty() } } -impl Default for None { +impl Default for Stale { fn default() -> Self { - None { span: Default::default(), ty: None } + Stale { span: Default::default(), ty: None } } } diff --git a/mir/src/ir/owner.rs b/mir/src/ir/owner.rs index f2db53dd5..ec907f06a 100644 --- a/mir/src/ir/owner.rs +++ b/mir/src/ir/owner.rs @@ -2,7 +2,7 @@ use std::ops::Deref; use miden_diagnostics::Spanned; -use crate::ir::{BackLink, Child, Link, Node, None, Op, Parent, Root}; +use crate::ir::{BackLink, Child, Link, Node, Stale, Op, Parent, Root}; /// The nodes that can own [Op] nodes /// The [Owner] enum does not own it's inner struct to avoid reference cycles, @@ -30,7 +30,7 @@ pub enum Owner { Enf(BackLink), For(BackLink), If(BackLink), - None(None), + None(Stale), } impl Parent for Owner { @@ -209,7 +209,7 @@ impl Link { }; } else { // If the [Owner] is stale, we set it to None - to_update = Owner::None(None { span: self.span(), ty: None }); + to_update = Owner::None(Stale { span: self.span(), ty: None }); } *self.borrow_mut() = to_update; diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 0f6eaaf30..193595714 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -11,14 +11,9 @@ use air_types::*; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; use crate::{ - CompileError, ir::{ - Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, - Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, - MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, Root, - SpannedMirValue, Sub, TraceAccess, TraceAccessBinding, Type, Value, Vector, none, - }, - passes::duplicate_node, + Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, Root, SpannedMirValue, Stale, Sub, TraceAccess, TraceAccessBinding, Type, Value, Vector + }, passes::duplicate_node, CompileError }; /// This pass transforms a given [ast::Program] into a Middle Intermediate Representation ([Mir]) @@ -493,8 +488,8 @@ impl<'a> MirBuilder<'a> { let for_node = For::create( iterator_nodes.into(), - Op::None(none::None { span: list_comp.span(), ty: None }).into(), - Op::None(none::None { span: list_comp.span(), ty: None }).into(), + Op::None(Stale { span: list_comp.span(), ty: None }).into(), + Op::None(Stale { span: list_comp.span(), ty: None }).into(), list_comp.span(), ); set_all_ref_nodes(params, for_node.as_owner().unwrap()); @@ -799,7 +794,7 @@ impl<'a> MirBuilder<'a> { })?; }, } - return Ok(Op::None(none::None { span: bin_op.span(), ty: None }).into()); + return Ok(Op::None(Stale { span: bin_op.span(), ty: None }).into()); } } } From f68eb250a94bb435da95ae16dfa2247de4297999 Mon Sep 17 00:00:00 2001 From: Leo-Besancon Date: Mon, 15 Sep 2025 09:25:25 +0200 Subject: [PATCH 33/42] fix: us correct op for infer_bin_ty_* --- mir/src/ir/nodes/ops/exp.rs | 2 +- mir/src/ir/nodes/ops/mul.rs | 2 +- mir/src/ir/nodes/ops/sub.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/mir/src/ir/nodes/ops/exp.rs b/mir/src/ir/nodes/ops/exp.rs index 3bda91cf9..7896dcd3a 100644 --- a/mir/src/ir/nodes/ops/exp.rs +++ b/mir/src/ir/nodes/ops/exp.rs @@ -40,7 +40,7 @@ impl Typing for Exp { impl BuilderHook for Exp { fn finalize_hook(&mut self) { self._bin_ty = BinType::Exp(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); - let res = self._bin_ty.infer_bin_ty_add().unwrap(); + let res = self._bin_ty.infer_bin_ty_exp().unwrap(); *self._bin_ty.result_mut() = res; } } diff --git a/mir/src/ir/nodes/ops/mul.rs b/mir/src/ir/nodes/ops/mul.rs index b4a678413..6e9158795 100644 --- a/mir/src/ir/nodes/ops/mul.rs +++ b/mir/src/ir/nodes/ops/mul.rs @@ -38,7 +38,7 @@ impl Typing for Mul { impl BuilderHook for Mul { fn finalize_hook(&mut self) { self._bin_ty = BinType::Mul(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); - let res = self._bin_ty.infer_bin_ty_add().unwrap(); + let res = self._bin_ty.infer_bin_ty_mul().unwrap(); *self._bin_ty.result_mut() = res; } } diff --git a/mir/src/ir/nodes/ops/sub.rs b/mir/src/ir/nodes/ops/sub.rs index 5647a62e1..8e2c8c455 100644 --- a/mir/src/ir/nodes/ops/sub.rs +++ b/mir/src/ir/nodes/ops/sub.rs @@ -38,7 +38,7 @@ impl Typing for Sub { impl BuilderHook for Sub { fn finalize_hook(&mut self) { self._bin_ty = BinType::Sub(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); - let res = self._bin_ty.infer_bin_ty_add().unwrap(); + let res = self._bin_ty.infer_bin_ty_sub().unwrap(); *self._bin_ty.result_mut() = res; } } From 7efee3194bc13249ec0a5ff318e6ac1b78f1fcc2 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 17 Sep 2025 18:03:32 +0200 Subject: [PATCH 34/42] refactor(types): change *_mut api + support for RefCell, Ref, and RefMut --- types/src/lib.rs | 130 +++++++++++++++++++++++++++++++++++++---------- 1 file changed, 103 insertions(+), 27 deletions(-) diff --git a/types/src/lib.rs b/types/src/lib.rs index 34b718829..2f15722ae 100644 --- a/types/src/lib.rs +++ b/types/src/lib.rs @@ -1,6 +1,10 @@ mod types; -use std::fmt::Debug; +use std::{ + cell::{Ref, RefCell, RefMut}, + fmt::Debug, + ops::{Deref, DerefMut}, +}; use miden_diagnostics::{SourceSpan, Span}; pub use types::*; @@ -322,16 +326,16 @@ pub trait Typing { } pub trait ScalarTypeMut: Typing { - fn scalar_ty_mut(&mut self) -> &mut Option; + fn update_scalar_ty_unchecked(&mut self, new_ty: Option); fn update_scalar_ty(&mut self, new_ty: Option) -> Result<(), TypeError> { let ty = self.scalar_ty(); if ty.is_none() { // WARN: This should only be true before type inference // Any None type should raise a diagnostic after type inference - *self.scalar_ty_mut() = new_ty; + self.update_scalar_ty_unchecked(new_ty); } else if ty.is_scalar_subtype(&new_ty) { // Allow widening of types - *self.scalar_ty_mut() = new_ty; + self.update_scalar_ty_unchecked(new_ty); } else { return Err(TypeError::IncompatibleScalarTypes { lhs: ty, rhs: new_ty, span: None }); } @@ -340,16 +344,16 @@ pub trait ScalarTypeMut: Typing { } pub trait TypeMut: Typing + ScalarTypeMut { - fn ty_mut(&mut self) -> &mut Option; + fn update_ty_unchecked(&mut self, new_ty: Option); fn update_ty(&mut self, new_ty: Option) -> Result<(), TypeError> { let ty = self.ty(); if ty.is_none() { // WARN: This should only be true before type inference // Any None type should raise a diagnostic after type inference - *self.ty_mut() = new_ty; + self.update_ty_unchecked(new_ty); } else if ty.is_subtype(&new_ty) { // Allow widening of types - *self.ty_mut() = new_ty; + self.update_ty_unchecked(new_ty); } else { return Err(TypeError::NotASubtype { lhs: ty, rhs: new_ty, span: None }); } @@ -453,11 +457,11 @@ impl Typing for Type { } impl ScalarTypeMut for Type { - fn scalar_ty_mut(&mut self) -> &mut Option { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { match self { - Type::Scalar(st) => st, - Type::Vector(st, _) => st, - Type::Matrix(st, ..) => st, + Type::Scalar(st) => *st = new_ty, + Type::Vector(st, _) => *st = new_ty, + Type::Matrix(st, ..) => *st = new_ty, } } } @@ -475,14 +479,14 @@ impl Typing for FunctionType { } impl ScalarTypeMut for BinType { - fn scalar_ty_mut(&mut self) -> &mut Option { - self.result_mut().scalar_ty_mut() + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.result_mut().update_scalar_ty_unchecked(new_ty); } } impl TypeMut for BinType { - fn ty_mut(&mut self) -> &mut Option { - self.result_mut() + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.result_mut().update_ty_unchecked(new_ty); } } @@ -507,9 +511,9 @@ impl Typing for BinType { } impl ScalarTypeMut for Kind { - fn scalar_ty_mut(&mut self) -> &mut Option { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { match self { - Kind::Value(ty) => ty.scalar_ty_mut(), + Kind::Value(ty) => ty.update_scalar_ty_unchecked(new_ty), Kind::Aggregate(_) => panic!("Cannot mutate scalar type of an aggregate kind"), Kind::Callable(_) => panic!("Cannot mutate scalar type of a callable kind"), } @@ -517,9 +521,9 @@ impl ScalarTypeMut for Kind { } impl TypeMut for Kind { - fn ty_mut(&mut self) -> &mut Option { + fn update_ty_unchecked(&mut self, new_ty: Option) { match self { - Kind::Value(ty) => ty, + Kind::Value(ty) => ty.update_ty_unchecked(new_ty), Kind::Aggregate(_) => panic!("Cannot mutate type of an aggregate kind"), Kind::Callable(_) => panic!("Cannot mutate type of a callable kind"), } @@ -555,25 +559,25 @@ impl Typing for Kind { } impl ScalarTypeMut for Option { - fn scalar_ty_mut(&mut self) -> &mut Option { - self + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + *self = new_ty; } } impl ScalarTypeMut for Option { - fn scalar_ty_mut(&mut self) -> &mut Option { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { match self { - Some(Type::Scalar(st)) => st, - Some(Type::Vector(st, _)) => st, - Some(Type::Matrix(st, ..)) => st, + Some(Type::Scalar(st)) => *st = new_ty, + Some(Type::Vector(st, _)) => *st = new_ty, + Some(Type::Matrix(st, ..)) => *st = new_ty, None => panic!("Cannot mutate scalar type of None"), } } } impl TypeMut for Option { - fn ty_mut(&mut self) -> &mut Option { - self + fn update_ty_unchecked(&mut self, new_ty: Option) { + *self = new_ty; } } @@ -598,6 +602,78 @@ impl Typing for Box { } } +impl ScalarTypeMut for RefCell { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.borrow_mut().update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for RefCell { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.borrow_mut().update_ty_unchecked(new_ty); + } +} + +impl Typing for RefCell { + fn kind(&self) -> Option { + self.borrow().kind() + } + fn ty(&self) -> Option { + self.borrow().ty() + } + fn scalar_ty(&self) -> Option { + self.borrow().scalar_ty() + } +} + +impl ScalarTypeMut for RefMut<'_, T> { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.deref_mut().update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for RefMut<'_, T> { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.deref_mut().update_ty_unchecked(new_ty); + } +} + +impl Typing for RefMut<'_, T> { + fn kind(&self) -> Option { + self.deref().kind() + } + fn ty(&self) -> Option { + self.deref().ty() + } + fn scalar_ty(&self) -> Option { + self.deref().scalar_ty() + } +} + +impl Typing for Ref<'_, T> { + fn kind(&self) -> Option { + self.deref().kind() + } + fn ty(&self) -> Option { + self.deref().ty() + } + fn scalar_ty(&self) -> Option { + self.deref().scalar_ty() + } +} + +impl ScalarTypeMut for Span { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.item.update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for Span { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.item.update_ty_unchecked(new_ty); + } +} + impl Typing for Span { fn kind(&self) -> Option { self.item.kind() From e21c2d2557944d997cd23db177d3281c680575ac Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 17 Sep 2025 18:07:07 +0200 Subject: [PATCH 35/42] refactor(types): update ast, expose Typ* traits through Link --- mir/src/ir/link.rs | 19 +++++++++++++++++++ parser/src/ast/expression.rs | 12 ++++++++---- 2 files changed, 27 insertions(+), 4 deletions(-) diff --git a/mir/src/ir/link.rs b/mir/src/ir/link.rs index db1cab648..fd92cdee1 100644 --- a/mir/src/ir/link.rs +++ b/mir/src/ir/link.rs @@ -5,6 +5,7 @@ use std::{ rc::{Rc, Weak}, }; +use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; /// A wrapper around a `Rc>` to allow custom trait implementations. @@ -110,6 +111,24 @@ where } } +impl Typing for Link { + fn ty(&self) -> Option { + self.borrow().ty() + } +} + +impl ScalarTypeMut for Link { + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + self.borrow_mut().update_scalar_ty_unchecked(new_ty); + } +} + +impl TypeMut for Link { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.borrow_mut().update_ty_unchecked(new_ty); + } +} + /// A wrapper around a `Option>>` to allow custom trait implementations. /// Used instead of `Link` where a `Link` would create a cyclIc reference. pub struct BackLink { diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index f4bb8a8a3..ad2628869 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -848,13 +848,17 @@ impl Typing for BinaryExpr { } } impl ScalarTypeMut for BinaryExpr { - fn scalar_ty_mut(&mut self) -> &mut Option { - self.bin_ty.as_mut().unwrap().scalar_ty_mut() + fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { + if let Some(bty) = self.bin_ty.as_mut() { + bty.update_scalar_ty_unchecked(new_ty); + } } } impl TypeMut for BinaryExpr { - fn ty_mut(&mut self) -> &mut Option { - self.bin_ty.as_mut().unwrap().ty_mut() + fn update_ty_unchecked(&mut self, new_ty: Option) { + if let Some(bty) = self.bin_ty.as_mut() { + bty.update_ty_unchecked(new_ty); + } } } From d57ae6e1d2c311f1fefe12c126947309e182fcb9 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 18 Sep 2025 14:38:48 +0200 Subject: [PATCH 36/42] feat(types): implement typing for the whole pipeline --- air/src/passes/translate_from_mir.rs | 6 +-- mir/src/ir/bus.rs | 2 +- mir/src/ir/nodes/op.rs | 8 +-- mir/src/ir/nodes/ops/bus_op.rs | 20 ------- mir/src/ir/nodes/ops/value.rs | 52 ++++++++++++++++++- mir/src/passes/translate.rs | 22 +++++--- .../passes/unrolling/unrolling_first_pass.rs | 40 +++++++------- parser/src/ast/declarations.rs | 5 ++ parser/src/ast/expression.rs | 19 +++---- parser/src/ast/statement.rs | 34 ++++++------ types/src/types.rs | 10 ++-- 11 files changed, 129 insertions(+), 89 deletions(-) diff --git a/air/src/passes/translate_from_mir.rs b/air/src/passes/translate_from_mir.rs index ed15f61b6..8851c4ffb 100644 --- a/air/src/passes/translate_from_mir.rs +++ b/air/src/passes/translate_from_mir.rs @@ -610,11 +610,7 @@ impl AirBuilder<'_> { .emit(); return Err(CompileError::Failed); } - MirTraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset, - row_offset: 0, - } + MirTraceAccess::new(trace_access_binding.segment, trace_access_binding.offset, 0) }, SpannedMirValue { value: MirValue::BusAccess(bus_access), .. diff --git a/mir/src/ir/bus.rs b/mir/src/ir/bus.rs index 81728f0dc..57822c7f3 100644 --- a/mir/src/ir/bus.rs +++ b/mir/src/ir/bus.rs @@ -166,7 +166,7 @@ impl Link { for column in columns { bus_op = bus_op.args(column.clone()); } - let bus_op = bus_op.latch(latch.clone()).ty(None).build(); + let bus_op = bus_op.latch(latch.clone()).build(); self.borrow_mut().columns.push(bus_op.clone()); self.borrow_mut().latches.push(latch.clone()); bus_op diff --git a/mir/src/ir/nodes/op.rs b/mir/src/ir/nodes/op.rs index ea38654d5..bed3b76aa 100644 --- a/mir/src/ir/nodes/op.rs +++ b/mir/src/ir/nodes/op.rs @@ -3,7 +3,7 @@ use std::{ ops::{Deref, DerefMut}, }; -use air_types::{ScalarTypeMut, TypeMut, Typing}; +use air_types::*; use miden_diagnostics::Spanned; use crate::ir::{ @@ -175,7 +175,7 @@ impl ScalarTypeMut for Op { Op::Vector(v) => v.update_scalar_ty_unchecked(new_ty), Op::Matrix(m) => m.update_scalar_ty_unchecked(new_ty), Op::Accessor(a) => a.update_scalar_ty_unchecked(new_ty), - Op::BusOp(b) => b.update_scalar_ty_unchecked(new_ty), + Op::BusOp(_) => {}, Op::Parameter(p) => p.update_scalar_ty_unchecked(new_ty), Op::Value(v) => v.update_scalar_ty_unchecked(new_ty), Op::None(n) => n.update_scalar_ty_unchecked(new_ty), @@ -199,7 +199,7 @@ impl TypeMut for Op { Op::Vector(v) => v.update_ty_unchecked(new_ty), Op::Matrix(m) => m.update_ty_unchecked(new_ty), Op::Accessor(a) => a.update_ty_unchecked(new_ty), - Op::BusOp(b) => b.update_ty_unchecked(new_ty), + Op::BusOp(_) => {}, Op::Parameter(p) => p.update_ty_unchecked(new_ty), Op::Value(v) => v.update_ty_unchecked(new_ty), Op::None(n) => n.update_ty_unchecked(new_ty), @@ -223,7 +223,7 @@ impl Typing for Op { Op::Vector(v) => v.ty(), Op::Matrix(m) => m.ty(), Op::Accessor(a) => a.ty(), - Op::BusOp(b) => b.ty(), + Op::BusOp(_) => ty!(?), Op::Parameter(p) => p.ty(), Op::Value(v) => v.ty(), Op::None(n) => n.ty(), diff --git a/mir/src/ir/nodes/ops/bus_op.rs b/mir/src/ir/nodes/ops/bus_op.rs index a7689e745..fd986d37a 100644 --- a/mir/src/ir/nodes/ops/bus_op.rs +++ b/mir/src/ir/nodes/ops/bus_op.rs @@ -1,6 +1,5 @@ use std::hash::Hash; -use air_types::*; use miden_diagnostics::{SourceSpan, Spanned}; use crate::ir::{ @@ -26,25 +25,6 @@ pub struct BusOp { pub _owner: Singleton, #[span] pub span: SourceSpan, - pub ty: Option, -} - -impl ScalarTypeMut for BusOp { - fn update_scalar_ty_unchecked(&mut self, new_ty: Option) { - self.ty.update_scalar_ty_unchecked(new_ty); - } -} - -impl TypeMut for BusOp { - fn update_ty_unchecked(&mut self, new_ty: Option) { - self.ty = new_ty; - } -} - -impl Typing for BusOp { - fn ty(&self) -> Option { - self.ty.ty() - } } impl BuilderHook for BusOp {} diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index 2b796b84f..56535ddf5 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -144,11 +144,30 @@ pub struct TraceAccess { /// For example, if accessing a trace column with `a'`, where `a` is bound to a single column, /// the row offset would be `1`, as the `'` modifier indicates the "next" row. pub row_offset: usize, + /// The type of the value being accessed, if known. + /// Defaults to None until the access is resolved. + /// This should only be a felt or [felt; n] type. + ty: Option, } impl TraceAccess { /// Creates a new [TraceAccess]. pub const fn new(segment: TraceSegmentId, column: TraceColumnIndex, row_offset: usize) -> Self { - Self { segment, column, row_offset } + Self { segment, column, row_offset, ty: None } + } +} +impl ScalarTypeMut for TraceAccess { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self.ty.update_scalar_ty_unchecked(new_sty); + } +} +impl TypeMut for TraceAccess { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} +impl Typing for TraceAccess { + fn ty(&self) -> Option { + self.ty.ty() } } @@ -160,6 +179,11 @@ pub struct TraceAccessBinding { /// The number of columns which are bound pub size: usize, } +impl Typing for TraceAccessBinding { + fn ty(&self) -> Option { + ty!(felt[self.size]) + } +} /// Represents a typed value in the MIR. #[derive(Debug, Eq, PartialEq, Clone, Hash, Spanned)] @@ -188,10 +212,28 @@ pub struct PublicInputAccess { pub name: Identifier, /// The index of the element in the public input to access pub index: usize, + /// The type of the value being accessed, if known. + /// Defaults to None until the access is resolved. + ty: Option, } impl PublicInputAccess { pub const fn new(name: Identifier, index: usize) -> Self { - Self { name, index } + Self { name, index, ty: None } + } +} +impl ScalarTypeMut for PublicInputAccess { + fn update_scalar_ty_unchecked(&mut self, new_sty: Option) { + self.ty.update_scalar_ty_unchecked(new_sty); + } +} +impl TypeMut for PublicInputAccess { + fn update_ty_unchecked(&mut self, new_ty: Option) { + self.ty = new_ty; + } +} +impl Typing for PublicInputAccess { + fn ty(&self) -> Option { + self.ty.ty() } } @@ -230,3 +272,9 @@ impl Default for SpannedMirValue { } } } + +impl Typing for PublicInputTableAccess { + fn ty(&self) -> Option { + ty!(felt[self.num_cols, usize::MAX]) + } +} diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 193595714..cd64199fc 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -11,9 +11,14 @@ use air_types::*; use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned}; use crate::{ + CompileError, ir::{ - Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, Root, SpannedMirValue, Stale, Sub, TraceAccess, TraceAccessBinding, Type, Value, Vector - }, passes::duplicate_node, CompileError + Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, + Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, + MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, Root, + SpannedMirValue, Stale, Sub, TraceAccess, TraceAccessBinding, Type, Value, Vector, + }, + passes::duplicate_node, }; /// This pass transforms a given [ast::Program] into a Middle Intermediate Representation ([Mir]) @@ -712,7 +717,7 @@ impl<'a> MirBuilder<'a> { span: access.span(), value: MirValue::BusAccess(BusAccess::new(bus.clone(), access.offset)), }) - .ty(bus.borrow().ty()) + .ty(ty!(?)) .build(); Ok(node) } else { @@ -1186,7 +1191,7 @@ impl<'a> MirBuilder<'a> { bus_op = bus_op.args(arg_node); } // Latch is unknown at this point, will be set later in translate_bus_enforce - let bus_op = bus_op.latch(1.into()).ty(ast_bus_op.ty()).build(); + let bus_op = bus_op.latch(1.into()).build(); Ok(bus_op) } @@ -1255,12 +1260,13 @@ impl<'a> MirBuilder<'a> { } if let Some(tab) = self.trace_access_binding(access) { + let typ = tab.ty(); return Ok(Value::builder() .value(SpannedMirValue { span: access.span(), value: MirValue::TraceAccessBinding(tab), }) - .ty(tab.ty()) + .ty(typ) .build()); } @@ -1304,12 +1310,13 @@ impl<'a> MirBuilder<'a> { // Otherwise, we check bindings, trace bindings, and public inputs, in that order if let Some(tab) = self.trace_access_binding(access) { + let typ = tab.ty(); return Ok(Value::builder() .value(SpannedMirValue { span: access.span(), value: MirValue::TraceAccessBinding(tab), }) - .ty(tab.ty()) + .ty(typ) .build()); } @@ -1324,12 +1331,13 @@ impl<'a> MirBuilder<'a> { .build()); }, (None, Some(public_input_table_access)) => { + let typ = public_input_table_access.ty(); return Ok(Value::builder() .value(SpannedMirValue { span: access.span(), value: MirValue::PublicInputTable(public_input_table_access), }) - .ty(public_input_table_access.ty()) + .ty(typ) .build()); }, _ => {}, diff --git a/mir/src/passes/unrolling/unrolling_first_pass.rs b/mir/src/passes/unrolling/unrolling_first_pass.rs index ffe99e334..96a53880c 100644 --- a/mir/src/passes/unrolling/unrolling_first_pass.rs +++ b/mir/src/passes/unrolling/unrolling_first_pass.rs @@ -58,22 +58,22 @@ fn unroll_trace_access_binding( if trace_access_binding.size == 1 { Value::create(SpannedMirValue { span, - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset, - row_offset: 0, - }), + value: MirValue::TraceAccess(TraceAccess::new( + trace_access_binding.segment, + trace_access_binding.offset, + 0, + )), }) } else { let mut vec = vec![]; for index in 0..trace_access_binding.size { let val = Value::create(SpannedMirValue { span, - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access_binding.segment, - column: trace_access_binding.offset + index, - row_offset: 0, - }), + value: MirValue::TraceAccess(TraceAccess::new( + trace_access_binding.segment, + trace_access_binding.offset + index, + 0, + )), }); vec.push(val); } @@ -155,11 +155,11 @@ fn unroll_accessor_default_access_type( if let MirValue::TraceAccess(trace_access) = mir_value { let new_node = Value::create(SpannedMirValue { span: value.value.span(), - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access.segment, - column: trace_access.column, - row_offset: trace_access.row_offset + accessor_offset, - }), + value: MirValue::TraceAccess(TraceAccess::new( + trace_access.segment, + trace_access.column, + trace_access.row_offset + accessor_offset, + )), }); return Some(new_node); } @@ -187,11 +187,11 @@ fn unroll_accessor_index_access_type( MirValue::TraceAccess(trace_access) => { let new_node = Value::create(SpannedMirValue { span: value.value.span(), - value: MirValue::TraceAccess(TraceAccess { - segment: trace_access.segment, - column: trace_access.column, - row_offset: trace_access.row_offset + accessor_offset, - }), + value: MirValue::TraceAccess(TraceAccess::new( + trace_access.segment, + trace_access.column, + trace_access.row_offset + accessor_offset, + )), }); Some(new_node) }, diff --git a/parser/src/ast/declarations.rs b/parser/src/ast/declarations.rs index fe6df2f63..4a2cc8896 100644 --- a/parser/src/ast/declarations.rs +++ b/parser/src/ast/declarations.rs @@ -306,6 +306,11 @@ impl PartialEq for PeriodicColumn { self.name == other.name && self.values == other.values } } +impl Typing for PeriodicColumn { + fn ty(&self) -> Option { + ty!(felt[self.period()]) + } +} /// Declaration of a public input for an AirScript program. /// diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index ad2628869..8742fe818 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -698,15 +698,6 @@ impl RangeExpr { self.try_into() .expect("attempted to convert non-constant range expression to constant") } - - pub fn ty(&self) -> Option { - match (&self.start, &self.end) { - (RangeBound::Const(start), RangeBound::Const(end)) => { - ty!(uint[end.item.abs_diff(start.item)]) - }, - _ => None, - } - } } impl From for RangeExpr { fn from(range: Range) -> Self { @@ -734,6 +725,16 @@ impl fmt::Display for RangeExpr { write!(f, "{}..{}", &self.start, &self.end) } } +impl Typing for RangeExpr { + fn ty(&self) -> Option { + match (&self.start, &self.end) { + (RangeBound::Const(start), RangeBound::Const(end)) => { + ty!(uint[end.item.abs_diff(start.item)]) + }, + _ => None, + } + } +} #[derive(Hash, Clone, Spanned, PartialEq, Eq, Debug)] pub enum RangeBound { diff --git a/parser/src/ast/statement.rs b/parser/src/ast/statement.rs index 795dbd8d2..3a806dcd7 100644 --- a/parser/src/ast/statement.rs +++ b/parser/src/ast/statement.rs @@ -200,7 +200,24 @@ impl Let { pub fn new(span: SourceSpan, name: Identifier, value: Expr, body: Vec) -> Self { Self { span, name, value, body } } +} +impl Eq for Let {} +impl PartialEq for Let { + fn eq(&self, other: &Self) -> bool { + self.name == other.name && self.value == other.value && self.body == other.body + } +} +impl fmt::Debug for Let { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + f.debug_struct("Let") + .field("name", &self.name) + .field("value", &self.value) + .field("body", &self.body) + .finish() + } +} +impl Typing for Let { /// Return the type of the overall `let` expression. /// /// A `let` with an empty body, or with a body that terminates with a non-expression statement @@ -210,7 +227,7 @@ impl Let { /// For `let` statements with a non-empty body that terminates with an expression, the `let` can /// be used in expression position, producing the value of the terminating expression in its /// body, and having the same type as that value. - pub fn ty(&self) -> Option { + fn ty(&self) -> Option { let mut last = self.body.last(); while let Some(stmt) = last.take() { match stmt { @@ -228,18 +245,3 @@ impl Let { None } } -impl Eq for Let {} -impl PartialEq for Let { - fn eq(&self, other: &Self) -> bool { - self.name == other.name && self.value == other.value && self.body == other.body - } -} -impl fmt::Debug for Let { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - f.debug_struct("Let") - .field("name", &self.name) - .field("value", &self.value) - .field("body", &self.body) - .finish() - } -} diff --git a/types/src/types.rs b/types/src/types.rs index 0fa0d8a8a..015106c41 100644 --- a/types/src/types.rs +++ b/types/src/types.rs @@ -1,6 +1,6 @@ use crate::{TypeError, Typing}; -#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum ScalarType { Felt, Bool, @@ -47,7 +47,7 @@ macro_rules! sty { } /// The types of values which can be represented in an AirScript program -#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum Type { // annotation: sty // where sty is the scalar type @@ -175,7 +175,7 @@ macro_rules! tty { } /// Represents the type signature of a function -#[derive(Hash, Debug, Clone, PartialEq, Eq)] +#[derive(Hash, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub enum FunctionType { /// An evaluator function, which has no results, and has /// a complex type signature due to the nature of trace bindings @@ -264,7 +264,7 @@ macro_rules! fty { }; } -#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq)] +#[derive(Hash, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] pub enum BinType { Eq(Option, Option, Option), Add(Option, Option, Option), @@ -698,7 +698,7 @@ impl BinType { } } -#[derive(Hash, Debug, Clone, PartialEq, Eq)] +#[derive(Hash, Debug, Clone, PartialEq, Eq, PartialOrd, Ord)] pub enum Kind { Value(Option), Aggregate(Vec>>), From fe751bb5c42ab33f1ee195d3e736e9757353ea76 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 18 Sep 2025 15:15:24 +0200 Subject: [PATCH 37/42] fix(types): fix Trace* Typing impl --- mir/src/ir/nodes/ops/mul.rs | 7 ++++++- mir/src/ir/nodes/ops/value.rs | 8 ++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/mir/src/ir/nodes/ops/mul.rs b/mir/src/ir/nodes/ops/mul.rs index e2cc346cc..2ebfe053c 100644 --- a/mir/src/ir/nodes/ops/mul.rs +++ b/mir/src/ir/nodes/ops/mul.rs @@ -37,7 +37,12 @@ impl Typing for Mul { impl BuilderHook for Mul { fn finalize_hook(&mut self) { - self._bin_ty = BinType::Mul(self.lhs.borrow().ty(), self.rhs.borrow().ty(), None); + eprintln!("Finalizing Mul Op: {:#?}", self); + let lty = self.lhs.borrow().infer_ty(); + let rty = self.rhs.borrow().infer_ty(); + eprintln!("LHS Type: {:#?}", lty); + eprintln!("RHS Type: {:#?}", rty); + self._bin_ty = BinType::Mul(lty.unwrap(), rty.unwrap(), None); let res = self._bin_ty.infer_bin_ty_mul().unwrap(); *self._bin_ty.result_mut() = res; } diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index 56535ddf5..c6366e5d8 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -181,7 +181,11 @@ pub struct TraceAccessBinding { } impl Typing for TraceAccessBinding { fn ty(&self) -> Option { - ty!(felt[self.size]) + if self.size == 1 { + ty!(felt) + } else { + ty!(felt[self.size]) + } } } @@ -275,6 +279,6 @@ impl Default for SpannedMirValue { impl Typing for PublicInputTableAccess { fn ty(&self) -> Option { - ty!(felt[self.num_cols, usize::MAX]) + ty!(felt[usize::MAX, self.num_cols]) } } From 3c18d059a911a67a801ab6a25709b286f206affd Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 18 Sep 2025 17:09:34 +0200 Subject: [PATCH 38/42] fix(types): various bug fixes --- mir/src/passes/translate.rs | 8 ++-- parser/src/sema/binding_type.rs | 5 +-- parser/src/sema/semantic_analysis.rs | 2 + types/src/types.rs | 55 +++++++++++++++++++++------- 4 files changed, 51 insertions(+), 19 deletions(-) diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index cd64199fc..1c1a9eeb8 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -299,7 +299,7 @@ impl<'a> MirBuilder<'a> { Type::Vector(_, size) => { let mut params = Vec::new(); for _ in 0..*size { - let param = Parameter::create(*i, ty!(felt[*size]).unwrap(), span); + let param = Parameter::create(*i, ty!(felt).unwrap(), span); *i += 1; params.push(param); } @@ -758,6 +758,7 @@ impl<'a> MirBuilder<'a> { &mut self, bin_op: &'a ast::BinaryExpr, ) -> Result, CompileError> { + eprintln!("translating binary op: {bin_op:#?}"); let lhs = self.translate_scalar_expr(&bin_op.lhs)?; let rhs = self.translate_scalar_expr(&bin_op.rhs)?; @@ -947,8 +948,9 @@ impl<'a> MirBuilder<'a> { } let arg_kinds = arg_nodes.iter().map(|arg| arg.kind().unwrap()).collect::>(); + eprintln!("arg kinds: {:#?}", arg_kinds); let arg_kinds_refs = arg_kinds.iter().collect::>(); - if callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { + if !callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { self.diagnostics .diagnostic(Severity::Error) .with_message("arguments typing mismatch") @@ -1005,7 +1007,7 @@ impl<'a> MirBuilder<'a> { } let arg_kinds = arg_nodes.iter().map(|arg| arg.kind().unwrap()).collect::>(); let arg_kinds_refs = arg_kinds.iter().collect::>(); - if callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { + if !callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { self.diagnostics .diagnostic(Severity::Error) .with_message("arguments typing mismatch") diff --git a/parser/src/sema/binding_type.rs b/parser/src/sema/binding_type.rs index a75c25f98..6d529a0b6 100644 --- a/parser/src/sema/binding_type.rs +++ b/parser/src/sema/binding_type.rs @@ -53,8 +53,7 @@ impl Typing for BindingType { Self::TraceColumn(tb) | Self::TraceParam(tb) => tb.kind(), Self::Vector(elems) => elems.kind(), Self::PublicInput(ty) => ty.kind(), - // NOTE: this may need to be felt? - Self::PeriodicColumn(_) => Some(kind!(bool)), + Self::PeriodicColumn(_) => Some(kind!(felt)), } } /// Get the value type of this binding, if applicable @@ -67,7 +66,7 @@ impl Typing for BindingType { Self::PeriodicColumn(_) => ty!(felt), Self::Function(_) => None, Self::Evaluator(_) => None, - Self::Bus(_) => ty!(felt), + Self::Bus(_) => None, } } } diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index c3ab4120e..fa30bbb26 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -1876,6 +1876,7 @@ impl SemanticAnalysis<'_> { } fn expr_binding_type(&self, expr: &Expr) -> Result { + eprintln!("sema::semantic_analysis::expr_binding_type"); match expr { Expr::Const(constant) => { Ok(BindingType::Local(constant.ty().expect("constant type should be known"))) @@ -1899,6 +1900,7 @@ impl SemanticAnalysis<'_> { Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)), Expr::Binary(be) => Ok(BindingType::Local(be.ty().or(ty!(felt)).unwrap())), Expr::ListComprehension(lc) => { + eprintln!("list comprehension: {lc:#?}"); match lc.ty { Some(ty) => Ok(BindingType::Local(ty)), None => { diff --git a/types/src/types.rs b/types/src/types.rs index 015106c41..5fb437638 100644 --- a/types/src/types.rs +++ b/types/src/types.rs @@ -206,20 +206,49 @@ impl FunctionType { } pub fn check_args_kinds(&self, args: &[&Kind]) -> bool { - eprintln!("Checking function type {self} against params {args:?}"); - let params = self.params(); - if params.len() != args.len() { - return false; - } - for (arg_ty, param_kind) in args.iter().zip(params.iter()) { - eprintln!(" Checking arg_ty {arg_ty:?} against param_kind {param_kind:?}"); - if !arg_ty.is_subtype(param_kind) { - eprintln!(" Failed!: {arg_ty:?} is not a subtype of {param_kind:?}"); - return false; - } + match self { + Self::Function(params, _) => { + if params.len() != args.len() { + return false; + } + for (arg_ty, param_kind) in args.iter().zip(params.iter()) { + if !arg_ty.is_subtype(param_kind) { + return false; + } + } + true + }, + Self::Evaluator(params) => { + // Only check that the number of columns match + // since evaluator arguments get matched as a vec of felt + let mut params_len = 0; + for param in params.iter() { + match param { + Some(Type::Scalar(_)) => params_len += 1, + Some(Type::Vector(_, len)) => params_len += *len, + Some(Type::Matrix(_, _, _)) => { + unreachable!("Evaluator functions cannot have matrix parameters") + }, + None => unreachable!("Evaluator functions cannot have untyped parameters"), + } + } + let mut args_len = 0; + for arg in args.iter() { + match arg.ty() { + Some(Type::Scalar(_)) => args_len += 1, + Some(Type::Vector(_, len)) => args_len += len, + Some(Type::Matrix(_, _, _)) => { + unreachable!("Evaluator functions cannot have matrix arguments") + }, + None => unreachable!("Evaluator functions cannot have untyped arguments"), + } + } + if params_len != args_len { + return false; + } + true + }, } - eprintln!(" Success!"); - true } } From 49339d6a10b9c74a1b5bc57d7c8c6bcdabe63ba4 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Thu, 18 Sep 2025 17:20:51 +0200 Subject: [PATCH 39/42] fix(types): more bug fixes, remove unused debugging --- mir/src/ir/nodes/ops/mul.rs | 3 --- mir/src/passes/translate.rs | 2 -- parser/src/sema/semantic_analysis.rs | 2 -- types/src/types.rs | 12 ++++-------- 4 files changed, 4 insertions(+), 15 deletions(-) diff --git a/mir/src/ir/nodes/ops/mul.rs b/mir/src/ir/nodes/ops/mul.rs index 2ebfe053c..c239174d5 100644 --- a/mir/src/ir/nodes/ops/mul.rs +++ b/mir/src/ir/nodes/ops/mul.rs @@ -37,11 +37,8 @@ impl Typing for Mul { impl BuilderHook for Mul { fn finalize_hook(&mut self) { - eprintln!("Finalizing Mul Op: {:#?}", self); let lty = self.lhs.borrow().infer_ty(); let rty = self.rhs.borrow().infer_ty(); - eprintln!("LHS Type: {:#?}", lty); - eprintln!("RHS Type: {:#?}", rty); self._bin_ty = BinType::Mul(lty.unwrap(), rty.unwrap(), None); let res = self._bin_ty.infer_bin_ty_mul().unwrap(); *self._bin_ty.result_mut() = res; diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 1c1a9eeb8..d568b6e35 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -758,7 +758,6 @@ impl<'a> MirBuilder<'a> { &mut self, bin_op: &'a ast::BinaryExpr, ) -> Result, CompileError> { - eprintln!("translating binary op: {bin_op:#?}"); let lhs = self.translate_scalar_expr(&bin_op.lhs)?; let rhs = self.translate_scalar_expr(&bin_op.rhs)?; @@ -948,7 +947,6 @@ impl<'a> MirBuilder<'a> { } let arg_kinds = arg_nodes.iter().map(|arg| arg.kind().unwrap()).collect::>(); - eprintln!("arg kinds: {:#?}", arg_kinds); let arg_kinds_refs = arg_kinds.iter().collect::>(); if !callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { self.diagnostics diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index fa30bbb26..c3ab4120e 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -1876,7 +1876,6 @@ impl SemanticAnalysis<'_> { } fn expr_binding_type(&self, expr: &Expr) -> Result { - eprintln!("sema::semantic_analysis::expr_binding_type"); match expr { Expr::Const(constant) => { Ok(BindingType::Local(constant.ty().expect("constant type should be known"))) @@ -1900,7 +1899,6 @@ impl SemanticAnalysis<'_> { Expr::Call(Call { ty: Some(ty), .. }) => Ok(BindingType::Local(*ty)), Expr::Binary(be) => Ok(BindingType::Local(be.ty().or(ty!(felt)).unwrap())), Expr::ListComprehension(lc) => { - eprintln!("list comprehension: {lc:#?}"); match lc.ty { Some(ty) => Ok(BindingType::Local(ty)), None => { diff --git a/types/src/types.rs b/types/src/types.rs index 5fb437638..d513f5ac3 100644 --- a/types/src/types.rs +++ b/types/src/types.rs @@ -227,9 +227,9 @@ impl FunctionType { Some(Type::Scalar(_)) => params_len += 1, Some(Type::Vector(_, len)) => params_len += *len, Some(Type::Matrix(_, _, _)) => { - unreachable!("Evaluator functions cannot have matrix parameters") + return false; }, - None => unreachable!("Evaluator functions cannot have untyped parameters"), + None => return false, } } let mut args_len = 0; @@ -238,9 +238,9 @@ impl FunctionType { Some(Type::Scalar(_)) => args_len += 1, Some(Type::Vector(_, len)) => args_len += len, Some(Type::Matrix(_, _, _)) => { - unreachable!("Evaluator functions cannot have matrix arguments") + return false; }, - None => unreachable!("Evaluator functions cannot have untyped arguments"), + None => return false, } } if params_len != args_len { @@ -706,20 +706,16 @@ impl BinType { /// - a ? to any power is still a ? pub fn infer_bin_ty_exp(&self) -> Result, TypeError> { if let Some(ret) = self.result() { - eprintln!("infer_bin_ty_exp: returning cached result {ret:?}"); return Ok(Some(ret)); } let lhs = self.lhs(); let rhs = self.rhs(); - eprintln!("infer_bin_ty_exp: lhs = {lhs:?}, rhs = {rhs:?}"); if !((lhs.is_scalar() | lhs.is_none()) && (rhs.is_scalar() | rhs.is_none())) { return Err(TypeError::IncompatibleBinOp { bin_ty: *self, span: None }); } - eprintln!(" MADE IT PAST THE SHAPE CHECK"); match self { bty!(any ^ uint) => Ok(lhs), bty!(any ^ felt) | bty!(any ^ bool) | bty!(any ^ _) | bty!(any ^ ?) => { - eprintln!(" ERROR: any ^ !uint"); Err(TypeError::NonConstantExponent { bin_ty: *self, span: None }) }, _ => unreachable!("Undefined case for infer_bin_ty_exp: {self}"), From 7d49a858daf3f99bb072a30331f1a599122815a1 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Fri, 19 Sep 2025 15:16:46 +0200 Subject: [PATCH 40/42] fix(types): fix Evaluator argument types --- mir/src/passes/translate.rs | 20 ++++++++++++++++---- 1 file changed, 16 insertions(+), 4 deletions(-) diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index d568b6e35..57277db79 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -3,7 +3,7 @@ use std::ops::Deref; use air_parser::{ LexicalScope, - ast::{self, AccessType, TraceSegmentId}, + ast::{self, Access, AccessType, TraceSegmentId}, symbols, }; use air_pass::Pass; @@ -160,11 +160,11 @@ impl<'a> MirBuilder<'a> { let span = binding.name.map_or(SourceSpan::UNKNOWN, |n| n.span()); let params = self.translate_params_ev(span, binding.name.as_ref(), &binding.ty, &mut i)?; - + let param_ty = binding.ty.access(AccessType::Index(0)).ok().or(ty!(felt)); for param in params { all_params_flatten_for_trace_segment.push(param.clone()); all_params_flatten.push(param.clone()); - all_params_ty_flatten.push(binding.ty()); + all_params_ty_flatten.push(param_ty); } } @@ -1003,7 +1003,19 @@ impl<'a> MirBuilder<'a> { .emit(); return Err(CompileError::Failed); } - let arg_kinds = arg_nodes.iter().map(|arg| arg.kind().unwrap()).collect::>(); + let mut arg_kinds = vec![]; + if let Some(first_arg) = arg_nodes.first() { + let Some(v) = first_arg.as_vector() else { + unreachable!( + "expected first argument to be a vector, got {:#?}", + first_arg + ); + }; + for element in v.elements.borrow().iter() { + arg_kinds.push(element.kind().unwrap()); + } + } + // let arg_kinds = arg_nodes.iter().map(|arg| arg.kind().unwrap()).collect::>(); let arg_kinds_refs = arg_kinds.iter().collect::>(); if !callee_ref.func_ty.check_args_kinds(&arg_kinds_refs) { self.diagnostics From 4020b8a6f9e923114baae8cfb83b9aae5696f411 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 24 Sep 2025 12:28:49 +0200 Subject: [PATCH 41/42] feat(mir): Cast primitive + translate assert_bool --- air/src/passes/translate_from_mir.rs | 6 +-- mir/src/ir/node.rs | 12 ++++- mir/src/ir/nodes/op.rs | 50 ++++++++++++++++-- mir/src/ir/nodes/ops/cast.rs | 51 +++++++++++++++++++ mir/src/ir/nodes/ops/mod.rs | 2 + mir/src/ir/nodes/ops/value.rs | 23 +++++++-- mir/src/ir/owner.rs | 13 ++++- mir/src/ir/quad_eval.rs | 3 +- mir/src/passes/constant_propagation.rs | 11 ++-- mir/src/passes/mod.rs | 18 ++++++- mir/src/passes/translate.rs | 16 +++--- mir/src/passes/unrolling/match_optimizer.rs | 2 +- .../passes/unrolling/unrolling_first_pass.rs | 17 +++++-- mir/src/passes/visitor.rs | 5 ++ 14 files changed, 198 insertions(+), 31 deletions(-) create mode 100644 mir/src/ir/nodes/ops/cast.rs diff --git a/air/src/passes/translate_from_mir.rs b/air/src/passes/translate_from_mir.rs index 8851c4ffb..815cca86e 100644 --- a/air/src/passes/translate_from_mir.rs +++ b/air/src/passes/translate_from_mir.rs @@ -252,7 +252,7 @@ impl AirBuilder<'_> { )); }; - let ConstantValue::Felt(rhs_value) = constant_value else { + let ConstantValue::Scalar(rhs_value) = constant_value else { return Err(CompileError::SemanticAnalysis( SemanticAnalysisError::InvalidExpr( ast::InvalidExprError::NonConstantExponent(rhs.span()), @@ -267,7 +267,7 @@ impl AirBuilder<'_> { let value = match mir_value { MirValue::Constant(constant_value) => { - if let ConstantValue::Felt(felt) = constant_value { + if let ConstantValue::Scalar(felt) = constant_value { crate::ir::Value::Constant(*felt) } else { unreachable!() @@ -323,7 +323,7 @@ impl AirBuilder<'_> { let value = match mir_value { MirValue::Constant(constant_value) => { - if let ConstantValue::Felt(felt) = constant_value { + if let ConstantValue::Scalar(felt) = constant_value { crate::ir::Value::Constant(*felt) } else { unreachable!() diff --git a/mir/src/ir/node.rs b/mir/src/ir/node.rs index 90a800326..6974893cc 100644 --- a/mir/src/ir/node.rs +++ b/mir/src/ir/node.rs @@ -2,7 +2,7 @@ use std::ops::Deref; use miden_diagnostics::Spanned; -use crate::ir::{BackLink, Child, Link, Stale, Op, Owner, Parent, Root}; +use crate::ir::{BackLink, Child, Link, Op, Owner, Parent, Root, Stale}; /// All the nodes that can be in the MIR Graph /// Combines all [Root] and [Op] variants @@ -33,6 +33,7 @@ pub enum Node { BusOp(BackLink), Parameter(BackLink), Value(BackLink), + Cast(BackLink), None(Stale), } @@ -64,6 +65,7 @@ impl PartialEq for Node { (Node::BusOp(lhs), Node::BusOp(rhs)) => lhs.to_link() == rhs.to_link(), (Node::Parameter(lhs), Node::Parameter(rhs)) => lhs.to_link() == rhs.to_link(), (Node::Value(lhs), Node::Value(rhs)) => lhs.to_link() == rhs.to_link(), + (Node::Cast(lhs), Node::Cast(rhs)) => lhs.to_link() == rhs.to_link(), (Node::None(_), Node::None(_)) => true, _ => false, } @@ -92,6 +94,7 @@ impl std::hash::Hash for Node { Node::BusOp(b) => b.to_link().hash(state), Node::Parameter(p) => p.to_link().hash(state), Node::Value(v) => v.to_link().hash(state), + Node::Cast(c) => c.to_link().hash(state), Node::None(_) => (), } } @@ -119,6 +122,7 @@ impl Parent for Node { Node::BusOp(b) => b.children(), Node::Parameter(_p) => Link::default(), Node::Value(_v) => Link::default(), + Node::Cast(c) => c.children(), Node::None(_) => Link::default(), } } @@ -146,6 +150,7 @@ impl Child for Node { Node::BusOp(b) => b.get_parents(), Node::Parameter(p) => p.get_parents(), Node::Value(v) => v.get_parents(), + Node::Cast(c) => c.get_parents(), Node::None(_) => Vec::default(), } } @@ -169,6 +174,7 @@ impl Child for Node { Node::BusOp(b) => b.add_parent(parent), Node::Parameter(p) => p.add_parent(parent), Node::Value(v) => v.add_parent(parent), + Node::Cast(c) => c.add_parent(parent), Node::None(_) => (), } } @@ -192,6 +198,7 @@ impl Child for Node { Node::BusOp(b) => b.remove_parent(parent), Node::Parameter(p) => p.remove_parent(parent), Node::Value(v) => v.remove_parent(parent), + Node::Cast(c) => c.remove_parent(parent), Node::None(_) => (), } } @@ -220,6 +227,7 @@ impl Link { Op::BusOp(_) => Node::BusOp(BackLink::from(op_inner_val)), Op::Parameter(_) => Node::Parameter(BackLink::from(op_inner_val)), Op::Value(_) => Node::Value(BackLink::from(op_inner_val)), + Op::Cast(_) => Node::Cast(BackLink::from(op_inner_val)), Op::None(none) => Node::None(none.clone()), }; } else if let Some(root_inner_val) = self.as_root() { @@ -276,6 +284,7 @@ impl Link { Node::BusOp(_) => None, Node::Parameter(_) => None, Node::Value(_) => None, + Node::Cast(_) => None, Node::None(_) => None, } } @@ -301,6 +310,7 @@ impl Link { Node::BusOp(inner) => inner.to_link(), Node::Parameter(inner) => inner.to_link(), Node::Value(inner) => inner.to_link(), + Node::Cast(inner) => inner.to_link(), Node::None(_) => None, } } diff --git a/mir/src/ir/nodes/op.rs b/mir/src/ir/nodes/op.rs index bed3b76aa..deaebf081 100644 --- a/mir/src/ir/nodes/op.rs +++ b/mir/src/ir/nodes/op.rs @@ -7,8 +7,8 @@ use air_types::*; use miden_diagnostics::Spanned; use crate::ir::{ - Accessor, Add, BackLink, Boundary, BuilderHook, BusOp, Call, Child, ConstantValue, Enf, Exp, - Fold, For, If, Link, Matrix, MirValue, Mul, Node, Owner, Parameter, Parent, Singleton, + Accessor, Add, BackLink, Boundary, BuilderHook, BusOp, Call, Cast, Child, ConstantValue, Enf, + Exp, Fold, For, If, Link, Matrix, MirValue, Mul, Node, Owner, Parameter, Parent, Singleton, SpannedMirValue, Stale, Sub, Value, Vector, get_inner, get_inner_mut, }; @@ -34,6 +34,7 @@ pub enum Op { BusOp(BusOp), Parameter(Parameter), Value(Value), + Cast(Cast), None(Stale), } @@ -56,6 +57,7 @@ impl BuilderHook for Op { Op::BusOp(b) => b.finalize_hook(), Op::Parameter(p) => p.finalize_hook(), Op::Value(v) => v.finalize_hook(), + Op::Cast(c) => c.finalize_hook(), Op::None(_) => {}, } } @@ -87,6 +89,7 @@ impl Parent for Op { Op::BusOp(b) => b.children(), Op::Parameter(_) => Link::default(), Op::Value(_) => Link::default(), + Op::Cast(c) => c.children(), Op::None(_) => Link::default(), } } @@ -112,6 +115,7 @@ impl Child for Op { Op::BusOp(b) => b.get_parents(), Op::Parameter(p) => p.get_parents(), Op::Value(v) => v.get_parents(), + Op::Cast(c) => c.get_parents(), Op::None(_) => Default::default(), } } @@ -133,6 +137,7 @@ impl Child for Op { Op::BusOp(b) => b.add_parent(parent), Op::Parameter(p) => p.add_parent(parent), Op::Value(v) => v.add_parent(parent), + Op::Cast(c) => c.add_parent(parent), Op::None(_) => {}, } } @@ -154,6 +159,7 @@ impl Child for Op { Op::BusOp(b) => b.remove_parent(parent), Op::Parameter(p) => p.remove_parent(parent), Op::Value(v) => v.remove_parent(parent), + Op::Cast(c) => c.remove_parent(parent), Op::None(_) => {}, } } @@ -178,6 +184,7 @@ impl ScalarTypeMut for Op { Op::BusOp(_) => {}, Op::Parameter(p) => p.update_scalar_ty_unchecked(new_ty), Op::Value(v) => v.update_scalar_ty_unchecked(new_ty), + Op::Cast(_) => {}, Op::None(n) => n.update_scalar_ty_unchecked(new_ty), } } @@ -202,6 +209,7 @@ impl TypeMut for Op { Op::BusOp(_) => {}, Op::Parameter(p) => p.update_ty_unchecked(new_ty), Op::Value(v) => v.update_ty_unchecked(new_ty), + Op::Cast(_) => {}, Op::None(n) => n.update_ty_unchecked(new_ty), } } @@ -226,6 +234,7 @@ impl Typing for Op { Op::BusOp(_) => ty!(?), Op::Parameter(p) => p.ty(), Op::Value(v) => v.ty(), + Op::Cast(c) => c.ty(), Op::None(n) => n.ty(), } } @@ -252,6 +261,7 @@ impl Link { Op::BusOp(b) => format!("Op::BusOp@{}({:#?})", self.get_ptr(), b), Op::Parameter(p) => format!("Op::Parameter@{}({:#?})", self.get_ptr(), p), Op::Value(v) => format!("Op::Value@{}({:#?})", self.get_ptr(), v), + Op::Cast(c) => format!("Op::Cast@{}({:#?})", self.get_ptr(), c), Op::None(_) => "Op::None".to_string(), } } @@ -327,6 +337,9 @@ impl Link { Op::Value(value) => { value._node = Singleton::from(node.clone()); }, + Op::Cast(cast) => { + cast._node = Singleton::from(node.clone()); + }, Op::None(_) => {}, } } @@ -377,6 +390,9 @@ impl Link { }, Op::Parameter(_parameter) => {}, Op::Value(_value) => {}, + Op::Cast(cast) => { + cast._owner = Singleton::from(owner.clone()); + }, Op::None(_) => {}, } } @@ -482,6 +498,12 @@ impl Link { value._node = Singleton::from(node.clone()); node }, + Op::Cast(Cast { _node: Singleton(Some(link)), .. }) => link.clone(), + Op::Cast(cast) => { + let node: Link = Node::Cast(back).into(); + cast._node = Singleton::from(node.clone()); + node + }, Op::None(none) => Node::None(none.clone()).into(), } } @@ -576,6 +598,12 @@ impl Link { }, Op::Parameter(_) => None, Op::Value(_) => None, + Op::Cast(Cast { _owner: Singleton(Some(link)), .. }) => Some(link.clone()), + Op::Cast(cast) => { + let owner: Link = Owner::Cast(back).into(); + cast._owner = Singleton::from(owner.clone()); + cast._owner.0.clone() + }, Op::None(_) => None, } } @@ -835,13 +863,29 @@ impl Link { _ => None, }) } + /// Try getting the current [Op]'s inner [Cast]. + /// Returns None if the current [Op] is not a [Cast] or the Rc count is zero. + pub fn as_cast(&self) -> Option> { + get_inner(self.borrow(), |op| match op { + Op::Cast(inner) => Some(inner), + _ => None, + }) + } + /// Try getting the current [Op]'s inner [Cast], borrowing mutably. + /// Returns None if the current [Op] is not a [Cast] or the Rc count is zero. + pub fn as_cast_mut(&self) -> Option> { + get_inner_mut(self.borrow_mut(), |op| match op { + Op::Cast(inner) => Some(inner), + _ => None, + }) + } } impl From for Link { fn from(value: i64) -> Self { Op::Value(Value { value: SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(value as u64)), + value: MirValue::Constant(ConstantValue::Scalar(value as u64)), ..Default::default() }, ..Default::default() diff --git a/mir/src/ir/nodes/ops/cast.rs b/mir/src/ir/nodes/ops/cast.rs new file mode 100644 index 000000000..8663638d8 --- /dev/null +++ b/mir/src/ir/nodes/ops/cast.rs @@ -0,0 +1,51 @@ +use air_types::*; +use miden_diagnostics::{SourceSpan, Spanned}; + +use crate::ir::{BackLink, Builder, BuilderHook, Child, Link, Node, Op, Owner, Parent, Singleton}; + +#[derive(Default, Clone, PartialEq, Eq, Debug, Hash, Builder, Spanned)] +#[enum_wrapper(Op)] +pub struct Cast { + pub parents: Vec>, + /// The value being cast + pub value: Link, + ty: Option, + pub _node: Singleton, + pub _owner: Singleton, + #[span] + pub span: SourceSpan, +} + +impl BuilderHook for Cast {} + +impl Cast { + pub fn create(value: Link, ty: Option, span: SourceSpan) -> Link { + let cast = Self { value, ty, span, ..Default::default() }; + Link::new(Op::Cast(cast)) + } +} +impl Typing for Cast { + fn ty(&self) -> Option { + self.ty.ty() + } +} + +impl Parent for Cast { + type Child = Op; + fn children(&self) -> Link>> { + Link::new(vec![self.value.clone()]) + } +} + +impl Child for Cast { + type Parent = Owner; + fn get_parents(&self) -> Vec> { + self.parents.clone() + } + fn add_parent(&mut self, parent: Link) { + self.parents.push(parent.into()); + } + fn remove_parent(&mut self, parent: Link) { + self.parents.retain(|p| *p != parent.clone().into()); + } +} diff --git a/mir/src/ir/nodes/ops/mod.rs b/mir/src/ir/nodes/ops/mod.rs index 5ead181af..496ccb4ea 100644 --- a/mir/src/ir/nodes/ops/mod.rs +++ b/mir/src/ir/nodes/ops/mod.rs @@ -3,6 +3,7 @@ mod add; mod boundary; mod bus_op; mod call; +mod cast; mod enf; mod exp; mod fold; @@ -20,6 +21,7 @@ pub use add::Add; pub use boundary::Boundary; pub use bus_op::{BusOp, BusOpKind}; pub use call::Call; +pub use cast::Cast; pub use enf::Enf; pub use exp::Exp; pub use fold::{Fold, FoldOperator}; diff --git a/mir/src/ir/nodes/ops/value.rs b/mir/src/ir/nodes/ops/value.rs index c6366e5d8..18aa45247 100644 --- a/mir/src/ir/nodes/ops/value.rs +++ b/mir/src/ir/nodes/ops/value.rs @@ -47,7 +47,7 @@ impl From for Value { fn from(value: i64) -> Self { Self { value: SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(value as u64)), + value: MirValue::Constant(ConstantValue::Scalar(value as u64)), span: Default::default(), }, ..Default::default() @@ -124,11 +124,28 @@ impl BusAccess { #[derive(Debug, Eq, PartialEq, Clone, Hash)] pub enum ConstantValue { - Felt(u64), + Scalar(u64), Vector(Vec), Matrix(Vec>), } +impl Typing for ConstantValue { + fn ty(&self) -> Option { + match self { + ConstantValue::Scalar(_) => ty!(uint), + ConstantValue::Vector(v) => ty!(uint[v.len()]), + ConstantValue::Matrix(m) => { + let row_count = m.len(); + if row_count == 0 { + return ty!(uint[usize::MAX, usize::MAX]); + } + let col_count = m.iter().map(|r| r.len()).max().unwrap_or(usize::MAX); + ty!(uint[row_count, col_count]) + }, + } + } +} + /// [TraceAccess] is like SymbolAccess, but is used to describe an access to a specific trace /// column or columns. #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] @@ -271,7 +288,7 @@ impl PublicInputTableAccess { impl Default for SpannedMirValue { fn default() -> Self { Self { - value: MirValue::Constant(ConstantValue::Felt(0)), + value: MirValue::Constant(ConstantValue::Scalar(0)), span: Default::default(), } } diff --git a/mir/src/ir/owner.rs b/mir/src/ir/owner.rs index ec907f06a..44d2c8a9e 100644 --- a/mir/src/ir/owner.rs +++ b/mir/src/ir/owner.rs @@ -2,7 +2,7 @@ use std::ops::Deref; use miden_diagnostics::Spanned; -use crate::ir::{BackLink, Child, Link, Node, Stale, Op, Parent, Root}; +use crate::ir::{BackLink, Child, Link, Node, Op, Parent, Root, Stale}; /// The nodes that can own [Op] nodes /// The [Owner] enum does not own it's inner struct to avoid reference cycles, @@ -30,6 +30,7 @@ pub enum Owner { Enf(BackLink), For(BackLink), If(BackLink), + Cast(BackLink), None(Stale), } @@ -53,6 +54,7 @@ impl Parent for Owner { Owner::Matrix(m) => m.children(), Owner::Accessor(a) => a.children(), Owner::BusOp(b) => b.children(), + Owner::Cast(c) => c.children(), Owner::None(_) => Link::default(), } } @@ -78,6 +80,7 @@ impl Child for Owner { Owner::Matrix(m) => m.get_parents(), Owner::Accessor(a) => a.get_parents(), Owner::BusOp(b) => b.get_parents(), + Owner::Cast(c) => c.get_parents(), Owner::None(_) => Vec::default(), } } @@ -99,6 +102,7 @@ impl Child for Owner { Owner::Matrix(m) => m.add_parent(parent), Owner::Accessor(a) => a.add_parent(parent), Owner::BusOp(b) => b.add_parent(parent), + Owner::Cast(c) => c.add_parent(parent), Owner::None(_) => (), } } @@ -120,6 +124,7 @@ impl Child for Owner { Owner::Matrix(m) => m.remove_parent(parent), Owner::Accessor(a) => a.remove_parent(parent), Owner::BusOp(b) => b.remove_parent(parent), + Owner::Cast(c) => c.remove_parent(parent), Owner::None(_) => (), } } @@ -145,6 +150,7 @@ impl PartialEq for Owner { (Owner::Matrix(lhs), Owner::Matrix(rhs)) => lhs.to_link() == rhs.to_link(), (Owner::Accessor(lhs), Owner::Accessor(rhs)) => lhs.to_link() == rhs.to_link(), (Owner::BusOp(lhs), Owner::BusOp(rhs)) => lhs.to_link() == rhs.to_link(), + (Owner::Cast(lhs), Owner::Cast(rhs)) => lhs.to_link() == rhs.to_link(), (Owner::None(_), Owner::None(_)) => true, _ => false, } @@ -171,6 +177,7 @@ impl std::hash::Hash for Owner { Owner::Matrix(m) => m.to_link().hash(state), Owner::Accessor(a) => a.to_link().hash(state), Owner::BusOp(b) => b.to_link().hash(state), + Owner::Cast(c) => c.to_link().hash(state), Owner::None(s) => s.hash(state), } } @@ -199,6 +206,7 @@ impl Link { Op::BusOp(_) => Owner::BusOp(BackLink::from(op_inner_val)), Op::Parameter(_) => unreachable!(), Op::Value(_) => unreachable!(), + Op::Cast(_) => Owner::Cast(BackLink::from(op_inner_val)), Op::None(none) => Owner::None(none.clone()), }; } else if let Some(root_inner_val) = self.as_root() { @@ -243,6 +251,7 @@ impl Link { Owner::Enf(_) => None, Owner::For(_) => None, Owner::If(_) => None, + Owner::Cast(_) => None, Owner::None(_) => None, } } @@ -266,6 +275,7 @@ impl Link { Owner::Enf(back) => back.to_link(), Owner::For(back) => back.to_link(), Owner::If(back) => back.to_link(), + Owner::Cast(back) => back.to_link(), Owner::None(_) => None, } } @@ -301,6 +311,7 @@ impl BackLink { Owner::Enf(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::For(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::If(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), + Owner::Cast(back) => back.to_link().map(|l| l.get_ptr()).unwrap_or(0), Owner::None(_) => 0, }) .unwrap_or(0) diff --git a/mir/src/ir/quad_eval.rs b/mir/src/ir/quad_eval.rs index d5b0774e6..9b37e8cad 100644 --- a/mir/src/ir/quad_eval.rs +++ b/mir/src/ir/quad_eval.rs @@ -113,7 +113,7 @@ impl RandomInputs { }, Op::Value(v) => { match &v.value.value { - MirValue::Constant(ConstantValue::Felt(c)) => { + MirValue::Constant(ConstantValue::Scalar(c)) => { let felt = Felt::new(*c); Ok(const_quad_felt(felt)) }, @@ -220,6 +220,7 @@ impl RandomInputs { ); Err(CompileError::Failed) }, + Op::Cast(cast) => self.eval(cast.value.clone()), } } } diff --git a/mir/src/passes/constant_propagation.rs b/mir/src/passes/constant_propagation.rs index 139eeab65..f44ce5fec 100644 --- a/mir/src/passes/constant_propagation.rs +++ b/mir/src/passes/constant_propagation.rs @@ -88,7 +88,7 @@ impl ConstantPropagation<'_> { match (get_inner_const(&lhs), get_inner_const(&rhs)) { (Some(0), _) | (_, Some(0)) => Ok(Some(Value::create(SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(0)), + value: MirValue::Constant(ConstantValue::Scalar(0)), span: mul_ref.span, }))), (Some(1), _) => Ok(Some(rhs)), @@ -109,12 +109,12 @@ impl ConstantPropagation<'_> { if let Some(0) = get_inner_const(&lhs) { Ok(Some(Value::create(SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(0)), + value: MirValue::Constant(ConstantValue::Scalar(0)), span: exp_ref.span, }))) } else if let Some(0) = get_inner_const(&rhs) { Ok(Some(Value::create(SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(1)), + value: MirValue::Constant(ConstantValue::Scalar(1)), span: exp_ref.span, }))) } else { @@ -169,6 +169,7 @@ impl Visitor for ConstantPropagation<'_> { | Node::BusOp(_) | Node::Value(_) | Node::Accessor(_) + | Node::Cast(_) | Node::None(_) => Ok(None), Node::Function(_) | Node::Evaluator(_) | Node::Call(_) => { unreachable!( @@ -219,7 +220,7 @@ fn get_inner_const(value: &Link) -> Option { Op::Value(Value { value: SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(c)), + value: MirValue::Constant(ConstantValue::Scalar(c)), .. }, .. @@ -275,7 +276,7 @@ fn try_fold_const_binary_op( }; if let Some(folded) = folded { let new_value = Value::create(SpannedMirValue { - value: MirValue::Constant(crate::ir::ConstantValue::Felt(folded)), + value: MirValue::Constant(crate::ir::ConstantValue::Scalar(folded)), span, }); updated_binary_op = Some(new_value); diff --git a/mir/src/passes/mod.rs b/mir/src/passes/mod.rs index 864b5f0ea..dcb7406f0 100644 --- a/mir/src/passes/mod.rs +++ b/mir/src/passes/mod.rs @@ -5,6 +5,7 @@ mod unrolling; mod visitor; use std::{collections::HashMap, ops::Deref}; +use air_types::Typing; pub use constant_propagation::ConstantPropagation; pub use inlining::Inlining; use miden_diagnostics::Spanned; @@ -13,8 +14,8 @@ pub use unrolling::Unrolling; pub use visitor::Visitor; use crate::ir::{ - Accessor, Add, Boundary, BusOp, Call, Enf, Exp, Fold, For, If, Link, MatchArm, Matrix, Mul, - Node, Op, Owner, Parameter, Parent, Sub, Value, Vector, + Accessor, Add, Boundary, BusOp, Call, Cast, Enf, Exp, Fold, For, If, Link, MatchArm, Matrix, + Mul, Node, Op, Owner, Parameter, Parent, Sub, Value, Vector, }; /// Helper to duplicate a MIR node and its children recursively @@ -200,6 +201,12 @@ pub fn duplicate_node( new_param }, Op::Value(value) => Value::create(value.value.clone()), + Op::Cast(cast) => { + let value = cast.value.clone(); + let ty = cast.ty(); + let new_expr = duplicate_node(value, current_replace_map); + Cast::create(new_expr, ty, cast.span()) + }, Op::None(none) => Op::None(none.clone()).into(), } } @@ -447,6 +454,13 @@ pub fn duplicate_node_or_replace( let new_node = Value::create(value.value.clone()); current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); }, + Op::Cast(cast) => { + let value = cast.value.clone(); + let ty = cast.ty(); + let new_expr = current_replace_map.get(&value.get_ptr()).unwrap().1.clone(); + let new_node = Cast::create(new_expr, ty, cast.span()); + current_replace_map.insert(node.get_ptr(), (node.clone(), new_node)); + }, Op::None(_) => {}, } } diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 57277db79..7fc9492c6 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -13,10 +13,11 @@ use miden_diagnostics::{DiagnosticsHandler, Severity, SourceSpan, Span, Spanned} use crate::{ CompileError, ir::{ - Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, ConstantValue, - Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, Matrix, Mir, - MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, PublicInputTableAccess, Root, - SpannedMirValue, Stale, Sub, TraceAccess, TraceAccessBinding, Type, Value, Vector, + Accessor, Add, Boundary, Builder, Bus, BusAccess, BusOp, BusOpKind, Call, Cast, + ConstantValue, Enf, Evaluator, Exp, Fold, FoldOperator, For, Function, If, Link, MatchArm, + Matrix, Mir, MirValue, Mul, Op, Owner, Parameter, PublicInputAccess, + PublicInputTableAccess, Root, SpannedMirValue, Stale, Sub, TraceAccess, TraceAccessBinding, + Type, Value, Vector, }, passes::duplicate_node, }; @@ -885,8 +886,9 @@ impl<'a> MirBuilder<'a> { let _ = self.insert_enforce(node); let bool_x = duplicate_node(x, &mut Default::default()); // TODO: cast to a bool - //bool_x.update_ty(ty!(bool)); - Ok(bool_x) + let cast = + Cast::builder().value(bool_x).span(call.span()).ty(ty!(bool)).build(); + Ok(cast) }, other => unimplemented!("unhandled builtin: {}", other), } @@ -1124,7 +1126,7 @@ impl<'a> MirBuilder<'a> { sty: Option, ) -> Result, CompileError> { let value = SpannedMirValue { - value: MirValue::Constant(ConstantValue::Felt(c)), + value: MirValue::Constant(ConstantValue::Scalar(c)), span, }; let node = Value::builder().value(value).ty(ty!(sty)).build(); diff --git a/mir/src/passes/unrolling/match_optimizer.rs b/mir/src/passes/unrolling/match_optimizer.rs index d21ca9b65..e00117e92 100644 --- a/mir/src/passes/unrolling/match_optimizer.rs +++ b/mir/src/passes/unrolling/match_optimizer.rs @@ -227,7 +227,7 @@ impl<'a> MatchOptimizer<'a> { // representing `enf x = y` let zero_node = Value::create(SpannedMirValue { span: Default::default(), - value: MirValue::Constant(ConstantValue::Felt(0)), + value: MirValue::Constant(ConstantValue::Scalar(0)), }); // The following unwrap is safe as we always have at least one constraint above let new_node_with_sub_zero = Sub::create(cur_node.unwrap(), zero_node, span); diff --git a/mir/src/passes/unrolling/unrolling_first_pass.rs b/mir/src/passes/unrolling/unrolling_first_pass.rs index 96a53880c..49a39de7b 100644 --- a/mir/src/passes/unrolling/unrolling_first_pass.rs +++ b/mir/src/passes/unrolling/unrolling_first_pass.rs @@ -1,7 +1,7 @@ use std::{collections::HashMap, ops::Deref}; use air_parser::ast::AccessType; -use air_types::{Type, ty}; +use air_types::{Type, Typing, ty}; use miden_diagnostics::{DiagnosticsHandler, SourceSpan, Spanned}; use crate::{ @@ -87,7 +87,7 @@ fn unroll_constant_vector(constant_vector: &Vec, span: SourceSpan) -> Link< for val in constant_vector { let val = Value::create(SpannedMirValue { span, - value: MirValue::Constant(ConstantValue::Felt(*val)), + value: MirValue::Constant(ConstantValue::Scalar(*val)), }); vec.push(val); } @@ -102,7 +102,7 @@ fn unroll_constant_matrix(constant_matrix: &Vec>, span: SourceSpan) -> for val in row { let val = Value::create(SpannedMirValue { span, - value: MirValue::Constant(ConstantValue::Felt(*val)), + value: MirValue::Constant(ConstantValue::Scalar(*val)), }); res_row.push(val); } @@ -264,7 +264,7 @@ impl UnrollingFirstPass<'_> { let mir_value = value_ref.value.value.clone(); match &mir_value { MirValue::Constant(c) => match c { - ConstantValue::Felt(_) => {}, + ConstantValue::Scalar(_) => {}, ConstantValue::Vector(v) => { return Ok(Some(unroll_constant_vector(v, value_ref.span()))); }, @@ -578,6 +578,14 @@ impl UnrollingFirstPass<'_> { ) -> Result>, CompileError> { Ok(None) // Matrix are already unrolled, we have nothing to do } + + fn visit_cast_bis( + &mut self, + _graph: &mut Graph, + _cast: Link, + ) -> Result>, CompileError> { + Ok(None) + } } impl Visitor for UnrollingFirstPass<'_> { @@ -640,6 +648,7 @@ impl Visitor for UnrollingFirstPass<'_> { to_link_and(p.clone(), graph, |g, el| self.visit_parameter_bis(g, el)) }, Node::Value(v) => to_link_and(v.clone(), graph, |g, el| self.visit_value_bis(g, el)), + Node::Cast(c) => to_link_and(c.clone(), graph, |g, el| self.visit_cast_bis(g, el)), Node::None(_) => Ok(None), Node::Function(_) | Node::Evaluator(_) | Node::Call(_) => { unreachable!( diff --git a/mir/src/passes/visitor.rs b/mir/src/passes/visitor.rs index 77fabc1e2..c733f857b 100644 --- a/mir/src/passes/visitor.rs +++ b/mir/src/passes/visitor.rs @@ -62,6 +62,7 @@ pub trait Visitor { Node::Accessor(a) => self.visit_accessor(graph, a.clone().into()), Node::BusOp(b) => self.visit_bus_op(graph, b.clone().into()), Node::Parameter(p) => self.visit_parameter(graph, p.clone().into()), + Node::Cast(c) => self.visit_cast(graph, c.clone().into()), Node::Value(v) => self.visit_value(graph, v.clone().into()), Node::None(_) => Ok(()), } @@ -154,6 +155,10 @@ pub trait Visitor { ) -> Result<(), CompileError> { Ok(()) } + /// Visit a `Cast` node + fn visit_cast(&mut self, _graph: &mut Graph, _cast: Link) -> Result<(), CompileError> { + Ok(()) + } /// Visit a `Value` node fn visit_value(&mut self, _graph: &mut Graph, _value: Link) -> Result<(), CompileError> { Ok(()) From bd2df160dad8be1189a7648b663e245e46de6511 Mon Sep 17 00:00:00 2001 From: Thybault Alabarbe Date: Wed, 24 Sep 2025 12:42:50 +0200 Subject: [PATCH 42/42] refactor(types): rename assert_bool -> as_bool --- mir/src/passes/translate.rs | 2 +- mir/src/tests/typing.rs | 2 +- parser/src/ast/expression.rs | 14 +++++++------- parser/src/sema/semantic_analysis.rs | 6 +++--- parser/src/symbols.rs | 6 +++--- parser/src/transforms/constant_propagation.rs | 2 +- 6 files changed, 16 insertions(+), 16 deletions(-) diff --git a/mir/src/passes/translate.rs b/mir/src/passes/translate.rs index 7fc9492c6..02b3fc909 100644 --- a/mir/src/passes/translate.rs +++ b/mir/src/passes/translate.rs @@ -863,7 +863,7 @@ impl<'a> MirBuilder<'a> { .build(); Ok(node) }, - symbols::AssertBool => { + symbols::AsBool => { assert_eq!(call.args.len(), 1); let x = self.translate_expr(call.args.first().unwrap())?; // enf x^2 = x diff --git a/mir/src/tests/typing.rs b/mir/src/tests/typing.rs index 7d4ff3bbd..5b8c93b3e 100644 --- a/mir/src/tests/typing.rs +++ b/mir/src/tests/typing.rs @@ -18,7 +18,7 @@ fn test_typing() { } integrity_constraints { - let b2 = assert_bool(b); + let b2 = as_bool(b); let c = select(a, b2); enf c = 42; } diff --git a/parser/src/ast/expression.rs b/parser/src/ast/expression.rs index 8742fe818..8c145f89d 100644 --- a/parser/src/ast/expression.rs +++ b/parser/src/ast/expression.rs @@ -160,7 +160,7 @@ impl QualifiedIdentifier { if self.module.name() == "$builtin" { match self.item { NamespacedIdentifier::Function(id) => { - matches!(id.name(), symbols::Sum | symbols::Prod | symbols::AssertBool) + matches!(id.name(), symbols::Sum | symbols::Prod | symbols::AsBool) }, _ => false, } @@ -1427,7 +1427,7 @@ impl Call { match callee.name() { symbols::Sum => Self::sum(span, args), symbols::Prod => Self::prod(span, args), - symbols::AssertBool => Self::assert_bool(span, args), + symbols::AsBool => Self::as_bool(span, args), _ => Self { span, callee: ResolvableIdentifier::Unresolved(NamespacedIdentifier::Function(callee)), @@ -1457,13 +1457,13 @@ impl Call { Self::new_builtin(span, "prod", args, ty!(felt).unwrap()) } - /// Constructs a function call for `assert_bool`. - /// An `assert_bool(x)` is equivalent to an `enf x^2 = x plus a cast from felt to bool`. + /// Constructs a function call for `as_bool`. + /// An `as_bool(x)` is equivalent to an `enf x^2 = x plus a cast from felt to bool`. #[inline] - pub fn assert_bool(span: SourceSpan, args: Vec) -> Self { - //Self::new_builtin(span, "assert_bool", args, ty!(felt).unwrap()) + pub fn as_bool(span: SourceSpan, args: Vec) -> Self { + //Self::new_builtin(span, "as_bool", args, ty!(felt).unwrap()) let builtin_module = Identifier::new(SourceSpan::UNKNOWN, Symbol::intern("$builtin")); - let name = Identifier::new(span, Symbol::intern("assert_bool")); + let name = Identifier::new(span, Symbol::intern("as_bool")); let id = QualifiedIdentifier::new(builtin_module, NamespacedIdentifier::Function(name)); let callee = ResolvableIdentifier::Resolved(id); let ty = ty!(bool); diff --git a/parser/src/sema/semantic_analysis.rs b/parser/src/sema/semantic_analysis.rs index c3ab4120e..4258ccf15 100644 --- a/parser/src/sema/semantic_analysis.rs +++ b/parser/src/sema/semantic_analysis.rs @@ -1179,7 +1179,7 @@ impl SemanticAnalysis<'_> { }, // The known built-in cast functions - each takes a single argument, which // must be a subtype of the expected type - symbols::AssertBool => { + symbols::AsBool => { match call.args.as_slice() { [arg] => { match self.expr_binding_type(arg) { @@ -1987,8 +1987,8 @@ impl SemanticAnalysis<'_> { let folder_ty = FunctionType::Function(vec![ty!(felt[usize::MAX])], ty!(felt)); Ok(Span::new(qid.span(), BindingType::Function(folder_ty))) }, - symbols::AssertBool => { - // An `assert_bool(x)` is equivalent to an `enf x^2 = x and + symbols::AsBool => { + // An `as_bool(x)` is equivalent to an `enf x^2 = x and // a cast from felt to bool`. Ok(Span::new(qid.span(), BindingType::Function(fty!(fn(felt) -> bool)))) }, diff --git a/parser/src/symbols.rs b/parser/src/symbols.rs index d917728f1..982a25561 100644 --- a/parser/src/symbols.rs +++ b/parser/src/symbols.rs @@ -17,15 +17,15 @@ pub mod predefined { pub const Sum: Symbol = Symbol::new(2); /// The symbol `prod` pub const Prod: Symbol = Symbol::new(3); - /// The symbol `assert_bool` - pub const AssertBool: Symbol = Symbol::new(4); + /// The symbol `as_bool` + pub const AsBool: Symbol = Symbol::new(4); pub(super) const __SYMBOLS: &[(Symbol, &str)] = &[ (Main, "$main"), (Builtin, "$builtin"), (Sum, "sum"), (Prod, "prod"), - (AssertBool, "assert_bool"), + (AsBool, "as_bool"), ]; } diff --git a/parser/src/transforms/constant_propagation.rs b/parser/src/transforms/constant_propagation.rs index 8d79f6f64..cf7e3b900 100644 --- a/parser/src/transforms/constant_propagation.rs +++ b/parser/src/transforms/constant_propagation.rs @@ -402,7 +402,7 @@ impl VisitMut for ConstantPropagation<'_> { } } }, - symbols::AssertBool => { + symbols::AsBool => { assert_eq!(call.args.len(), 1); match &call.args[0] { // If the assertion is a constant 0 or 1, it's valid