Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions hugr-core/src/export.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
use crate::Visibility;
use crate::extension::ExtensionRegistry;
use crate::hugr::internal::HugrInternals;
use crate::types::{FuncTypeBase, PolyFuncTypeBase, TypeRowLike};
use crate::types::{FuncTypeBase, GeneralSum, PolyFuncTypeBase, TypeRowLike};
use crate::{
Direction, Hugr, HugrView, IncomingPort, Node, NodeIndex as _, Port,
extension::{ExtensionId, OpDef, SignatureFunc},
Expand Down Expand Up @@ -889,7 +889,7 @@ impl<'a> Context<'a> {
);
self.make_term(table::Term::List(parts))
}
SumType::General { rows } => {
SumType::General(GeneralSum { rows, .. }) => {
let parts = self.bump.alloc_slice_fill_iter(
rows.iter()
.map(|row| table::SeqPart::Item(self.export_term(row, None))),
Expand Down
6 changes: 3 additions & 3 deletions hugr-core/src/extension/resolution/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ use super::{ExtensionCollectionError, WeakExtensionRegistry};
use crate::Node;
use crate::extension::{ExtensionRegistry, ExtensionSet};
use crate::ops::{DataflowOpTrait, OpType, Value};
use crate::types::{FuncValueType, Signature, SumType, Term, TypeRow};
use crate::types::{FuncValueType, GeneralSum, Signature, SumType, Term, TypeRow};

/// Collects every extension used to define the types in an operation.
///
Expand Down Expand Up @@ -205,7 +205,7 @@ pub(crate) fn collect_term_exts(
collect_term_exts(&f.input, used_extensions, missing_extensions);
collect_term_exts(&f.output, used_extensions, missing_extensions);
}
Term::RuntimeSum(SumType::General { rows }) => {
Term::RuntimeSum(SumType::General(GeneralSum { rows, .. })) => {
for row in rows {
collect_term_exts(row, used_extensions, missing_extensions);
}
Expand Down Expand Up @@ -271,7 +271,7 @@ fn collect_value_exts(
collect_term_exts(&typ, used_extensions, missing_extensions);
}
Value::Sum(s) => {
if let SumType::General { rows } = &s.sum_type {
if let SumType::General(GeneralSum { rows, .. }) = &s.sum_type {
for row in rows {
collect_term_exts(row, used_extensions, missing_extensions);
}
Expand Down
8 changes: 5 additions & 3 deletions hugr-core/src/extension/resolution/types_mut.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@ use super::types::collect_term_exts;
use super::{ExtensionResolutionError, WeakExtensionRegistry};
use crate::extension::ExtensionSet;
use crate::ops::{OpType, Value};
use crate::types::{CustomType, FuncValueType, Signature, SumType, Term, Type, TypeRow, TypeRowRV};
use crate::types::{
CustomType, FuncValueType, GeneralSum, Signature, SumType, Term, Type, TypeRow, TypeRowRV,
};
use crate::{Extension, Node};

/// Replace the dangling extension pointer in the [`CustomType`]s inside an
Expand Down Expand Up @@ -240,7 +242,7 @@ pub(super) fn resolve_term_exts(
Term::RuntimeFunction(f) => {
resolve_func_type_exts(node, &mut *f, extensions, used_extensions)?;
}
Term::RuntimeSum(SumType::General { rows }) => {
Term::RuntimeSum(SumType::General(GeneralSum { rows, .. })) => {
for row in rows.iter_mut() {
resolve_typerow_rv_exts(node, row, extensions, used_extensions)?;
}
Expand Down Expand Up @@ -302,7 +304,7 @@ pub(super) fn resolve_value_exts(
}
}
Value::Sum(s) => {
if let SumType::General { rows } = &mut s.sum_type {
if let SumType::General(GeneralSum { rows, .. }) = &mut s.sum_type {
for row in rows.iter_mut() {
resolve_typerow_rv_exts(node, row, extensions, used_extensions)?;
}
Expand Down
128 changes: 64 additions & 64 deletions hugr-core/src/types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -177,8 +177,34 @@ pub enum SumType {
#[allow(missing_docs)]
Unit { size: u8 },
/// General case of a Sum type.
#[allow(missing_docs)]
General { rows: Vec<TypeRowRV> },
General(GeneralSum),
}

/// The general case of a [SumType].
///
/// Can store any sum type, including those that can be more efficiently
/// represented as a [SumType::Unit].
#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub struct GeneralSum {
/// The types of the variants of the sum. Each variant is a row of types.
pub rows: Vec<TypeRowRV>,
#[serde(skip)]
bound: TypeBound,
}

impl GeneralSum {
/// Initialize a new `GeneralSum` with the given rows.
pub fn new(rows: Vec<TypeRowRV>) -> Self {
let bound = if rows
.iter()
.all(|row| check_term_type(row, &Term::new_list_type(TypeBound::Copyable)).is_ok())
{
TypeBound::Copyable
} else {
TypeBound::Linear
};
Self { rows, bound }
}
}

impl std::hash::Hash for SumType {
Expand All @@ -205,7 +231,7 @@ impl std::fmt::Display for SumType {
SumType::Unit { size } => {
display_list_with_separator(itertools::repeat_n("[]", *size as usize), f, "+")
}
SumType::General { rows } => match rows.len() {
SumType::General(GeneralSum { rows, .. }) => match rows.len() {
1 if rows[0].is_empty() => write!(f, "Unit"),
2 if rows[0].is_empty() && rows[1].is_empty() => write!(f, "Bool"),
_ => display_list_with_separator(rows.iter(), f, "+"),
Expand All @@ -225,7 +251,7 @@ impl SumType {
if u8::try_from(len).is_ok() && variants.iter().all(TypeRowRV::is_empty) {
Self::new_unary(len as u8)
} else {
Self::General { rows: variants }
Self::General(GeneralSum::new(variants))
}
}

Expand All @@ -250,7 +276,7 @@ impl SumType {
pub fn get_variant(&self, tag: usize) -> Option<&TypeRowRV> {
match self {
SumType::Unit { size } if tag < (*size as usize) => Some(TypeRowRV::EMPTY_REF),
SumType::General { rows } => rows.get(tag),
SumType::General(GeneralSum { rows, .. }) => rows.get(tag),
_ => None,
}
}
Expand All @@ -260,7 +286,7 @@ impl SumType {
pub fn num_variants(&self) -> usize {
match self {
SumType::Unit { size } => *size as usize,
SumType::General { rows } => rows.len(),
SumType::General(general) => general.rows.len(),
}
}

Expand All @@ -269,7 +295,7 @@ impl SumType {
pub fn as_tuple(&self) -> Option<&TypeRowRV> {
match self {
SumType::Unit { size } if *size == 1 => Some(TypeRowRV::EMPTY_REF),
SumType::General { rows } if rows.len() == 1 => Some(&rows[0]),
SumType::General(general) if general.rows.len() == 1 => Some(&general.rows[0]),
_ => None,
}
}
Expand All @@ -280,7 +306,9 @@ impl SumType {
pub fn as_option(&self) -> Option<&TypeRowRV> {
match self {
SumType::Unit { size } if *size == 2 => Some(TypeRowRV::EMPTY_REF),
SumType::General { rows } if rows.len() == 2 && rows[0].is_empty() => Some(&rows[1]),
SumType::General(GeneralSum { rows, .. }) if rows.len() == 2 && rows[0].is_empty() => {
Some(&rows[1])
}
_ => None,
}
}
Expand All @@ -291,23 +319,14 @@ impl SumType {
SumType::Unit { size } => {
Either::Left(itertools::repeat_n(TypeRowRV::EMPTY_REF, *size as usize))
}
SumType::General { rows } => Either::Right(rows.iter()),
SumType::General(general) => Either::Right(general.rows.iter()),
}
}

fn bound(&self) -> TypeBound {
const fn bound(&self) -> TypeBound {
match self {
SumType::Unit { .. } => TypeBound::Copyable,
SumType::General { rows } => {
if rows
.iter()
.all(|t| check_term_type(t, &Term::new_list_type(TypeBound::Copyable)).is_ok())
{
TypeBound::Copyable
} else {
TypeBound::Linear
}
}
SumType::General(GeneralSum { bound, .. }) => *bound,
}
}
}
Expand All @@ -316,7 +335,13 @@ impl Transformable for SumType {
fn transform<T: TypeTransformer>(&mut self, tr: &T) -> Result<bool, T::Err> {
match self {
SumType::Unit { .. } => Ok(false),
SumType::General { rows } => rows.transform(tr),
SumType::General(general) => {
let changed = general.rows.transform(tr)?;
if changed {
*general = GeneralSum::new(std::mem::take(&mut general.rows));
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this need a comment to say, recompute the bound?

}
Ok(changed)
}
}
}
}
Expand All @@ -325,7 +350,7 @@ impl From<SumType> for Type {
fn from(sum: SumType) -> Self {
match sum {
SumType::Unit { size } => Type::new_unit_sum(size),
SumType::General { rows } => Type::new_sum(rows),
SumType::General(GeneralSum { rows, .. }) => Type::new_sum(rows),
}
}
}
Expand Down Expand Up @@ -360,23 +385,17 @@ impl From<SumType> for Type {
/// let func_type: Type = Type::new_function(Signature::new_endo([]));
/// assert_eq!(func_type.least_upper_bound(), TypeBound::Copyable);
/// ```
pub struct Type(Term, TypeBound);
pub struct Type(Term);

impl Type {
/// An empty `TypeRow` or `TypeRowRV`. Provided here for convenience
pub const EMPTY_TYPEROW: TypeRow = TypeRow::new();
/// Unit type (empty tuple).
pub const UNIT: Self = Self(
Term::RuntimeSum(SumType::Unit { size: 1 }),
TypeBound::Copyable,
);
pub const UNIT: Self = Self(Term::RuntimeSum(SumType::Unit { size: 1 }));

/// Initialize a new function type.
pub fn new_function(fun_ty: impl Into<FuncValueType>) -> Self {
Self(
Term::RuntimeFunction(Box::new(fun_ty.into())),
TypeBound::Copyable,
)
Self(Term::RuntimeFunction(Box::new(fun_ty.into())))
}

/// Initialize a new tuple type by providing the elements.
Expand All @@ -396,26 +415,21 @@ impl Type {
R: Into<TypeRowRV>,
{
let st = SumType::new(variants);
let b = st.bound();
Self(Term::RuntimeSum(st), b)
Self(Term::RuntimeSum(st))
}

/// Initialize a new custom type.
// TODO remove? Extensions/TypeDefs should just provide `Type` directly
#[must_use]
pub const fn new_extension(opaque: CustomType) -> Self {
let bound = opaque.bound();
Self(Term::RuntimeExtension(opaque), bound)
Self(Term::RuntimeExtension(opaque))
}

/// New `UnitSum` with empty Tuple variants
#[must_use]
pub const fn new_unit_sum(size: u8) -> Self {
// should be the only way to avoid going through SumType::new
Self(
Term::RuntimeSum(SumType::new_unary(size)),
TypeBound::Copyable,
)
Self(Term::RuntimeSum(SumType::new_unary(size)))
}

/// New use (occurrence) of the type variable with specified index.
Expand All @@ -424,13 +438,13 @@ impl Type {
/// than required for the use.
#[must_use]
pub fn new_var_use(idx: usize, bound: TypeBound) -> Self {
Self(Term::new_var_use(idx, bound), bound)
Self(Term::new_var_use(idx, bound))
}

/// Report the least upper [`TypeBound`]
#[inline(always)]
pub const fn least_upper_bound(&self) -> TypeBound {
self.1
self.0.least_upper_bound().unwrap()
}

/// Report if the type is copyable - i.e.the least upper bound of the type
Expand All @@ -448,14 +462,6 @@ impl Type {
/// [TypeDef]: crate::extension::TypeDef
pub(crate) fn validate(&self, var_decls: &[TypeParam]) -> Result<(), SignatureError> {
self.0.validate(var_decls)?;
// ALAN even this should be only a debug-assert really:
// we have no unchecked access from outside crate::types
// so it must be a bug in our caching logic if this is wrong:
check_term_type(&self.0, &self.1.into())?;
debug_assert!(
self.1 == TypeBound::Copyable
|| check_term_type(&self.0, &TypeBound::Copyable.into()).is_err()
);
Ok(())
}

Expand All @@ -464,9 +470,7 @@ impl Type {
/// Always produces exactly one type, but may narrow the bound (from
/// [TypeBound::Linear] to [TypeBound::Copyable]).
fn substitute(&self, s: &Substitution) -> Self {
let t = self.0.substitute(s);
let b = t.least_upper_bound().unwrap(); // Recompute.
Self(t, b)
Self(self.0.substitute(s))
}

/// Returns a registry with the concrete extensions used by this type.
Expand All @@ -489,11 +493,7 @@ impl Type {

impl Transformable for Type {
fn transform<T: TypeTransformer>(&mut self, tr: &T) -> Result<bool, T::Err> {
let res = self.0.transform(tr)?;
if res {
self.1 = self.0.least_upper_bound().unwrap()
}
Ok(res)
self.0.transform(tr)
}
}

Expand All @@ -509,13 +509,13 @@ impl TryFrom<Term> for Type {
type Error = TermTypeError;

fn try_from(t: Term) -> Result<Self, TermTypeError> {
match t.least_upper_bound() {
Some(b) => Ok(Self(t, b)),
None => Err(TermTypeError::TypeMismatch {
term: Box::new(t),
type_: Box::new(TypeBound::Linear.into()),
}),
if t.is_runtime_type() {
return Ok(Self(t));
}
Err(TermTypeError::TypeMismatch {
term: Box::new(t),
type_: Box::new(TypeBound::Linear.into()),
})
}
}

Expand Down Expand Up @@ -701,7 +701,7 @@ pub(crate) mod test {
let empty_rows = vec![TypeRowRV::new(); 3];
let sum_unary = SumType::new_unary(3);
assert_eq!(empty_rows, sum_unary.variants().cloned().collect_vec());
let sum_general = SumType::General { rows: empty_rows };
let sum_general = SumType::General(GeneralSum::new(empty_rows));
assert_eq!(sum_general, sum_unary);

let mut hasher_general = std::hash::DefaultHasher::new();
Expand Down
Loading
Loading