diff --git a/huia-compiler/src/context/mod.rs b/huia-compiler/src/context/mod.rs index c3668bc..9464adb 100644 --- a/huia-compiler/src/context/mod.rs +++ b/huia-compiler/src/context/mod.rs @@ -33,13 +33,12 @@ impl Context { } // If there's already an anonymous trait with the same dependencies then return a reference to that instead. - for (i, ty) in self.types.iter().enumerate() { - if ty.is_trait() + if let Some(idx) = self.find_type_index(|ty| { + ty.is_trait() && ty.is_anonymous() && ty.get_dependencies() == Some(requirements.iter().collect()) - { - return i.into(); - } + }) { + return idx; } let idx = self.types.len(); @@ -58,7 +57,7 @@ impl Context { } pub fn current_file(&self) -> &str { - self.strings.get(self.current_file).unwrap() + self.strings.resolve(self.current_file).unwrap() } /// Define a function in the context and return a reference to it. @@ -90,11 +89,7 @@ impl Context { requirements: Vec, location: Location, ) -> TyIdx { - if self - .types - .iter() - .any(|t| t.name() == Some(name) && !t.is_unresolved()) - { + if self.find_type_index(|ty| ty.name() == Some(name)).is_some() { let message = format!("Type {} redefined", self.get_string(name).unwrap()); self.compile_error(&message, location.clone(), ErrorKind::TypeRedefined); } @@ -113,9 +108,8 @@ impl Context { location: Location, ) -> TyIdx { if self - .types - .iter() - .any(|t| t.name() == Some(name) && !t.is_unresolved()) + .find_type_index(|ty| ty.name() == Some(name) && !ty.is_unresolved()) + .is_some() { let message = format!("Type {} redefined", self.get_string(name).unwrap()); self.compile_error(&message, location.clone(), ErrorKind::TypeRedefined) @@ -146,9 +140,30 @@ impl Context { } #[cfg(test)] - pub fn find_type(&mut self, name: &str) -> Option<&Ty> { - let name = self.constant_string(name); - self.types.iter().find(|ty| ty.name() == Some(name)) + pub fn find_type_by_name(&self, name: &str) -> Option<&Ty> { + match self.strings.get(name) { + Some(idx) => self.find_type(|ty| ty.name() == Some(idx)), + None => None, + } + } + + /// Return the first type which matches the predicate. + pub fn find_type bool>(&self, predicate: F) -> Option<&Ty> { + match self.find_type_index(predicate) { + Some(idx) => self.types.get(usize::from(idx)), + None => None, + } + } + + /// Return the index of the first type that matches the predicate + pub fn find_type_index bool>(&self, predicate: F) -> Option { + for (idx, ty) in self.types.iter().enumerate() { + if predicate(ty) { + return Some(idx.into()); + } + } + + None } /// Retrieve a specific block by it's index. @@ -174,7 +189,7 @@ impl Context { } pub fn get_string(&self, idx: StringIdx) -> Option<&str> { - self.strings.get(idx) + self.strings.resolve(idx) } /// Retrieve a specific type by it's index. @@ -245,7 +260,7 @@ impl Context { /// Convert an AST `InputLocation` into a compiler `Location`. pub fn location(&self, location: &InputLocation) -> Location { - let path = self.strings.get(self.current_file).unwrap(); + let path = self.strings.resolve(self.current_file).unwrap(); Location::new(location.clone(), path) } diff --git a/huia-compiler/src/ir/builder.rs b/huia-compiler/src/ir/builder.rs index 93ab868..5953585 100644 --- a/huia-compiler/src/ir/builder.rs +++ b/huia-compiler/src/ir/builder.rs @@ -709,7 +709,7 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_constant()); - assert!(context.find_type("Huia.Native.Atom").is_some()); + assert!(context.find_type_by_name("Huia.Native.Atom").is_some()); } #[test] @@ -721,7 +721,7 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_constant()); - assert!(context.find_type("Huia.Native.Boolean").is_some()); + assert!(context.find_type_by_name("Huia.Native.Boolean").is_some()); } #[test] @@ -733,7 +733,7 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_constant()); - assert!(context.find_type("Huia.Native.Float").is_some()); + assert!(context.find_type_by_name("Huia.Native.Float").is_some()); } #[test] @@ -745,7 +745,7 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_constant()); - assert!(context.find_type("Huia.Native.String").is_some()); + assert!(context.find_type_by_name("Huia.Native.String").is_some()); } #[test] @@ -757,7 +757,7 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_type_reference()); - assert!(context.find_type("MartyMcFly").is_some()); + assert!(context.find_type_by_name("MartyMcFly").is_some()); } #[test] @@ -769,8 +769,8 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_constant()); - assert!(context.find_type("Huia.Native.Array").is_some()); - assert!(context.find_type("Huia.Native.Integer").is_some()); + assert!(context.find_type_by_name("Huia.Native.Array").is_some()); + assert!(context.find_type_by_name("Huia.Native.Integer").is_some()); } #[test] @@ -782,9 +782,9 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_constant()); - assert!(context.find_type("Huia.Native.Map").is_some()); - assert!(context.find_type("Huia.Native.Atom").is_some()); - assert!(context.find_type("Huia.Native.Integer").is_some()); + assert!(context.find_type_by_name("Huia.Native.Map").is_some()); + assert!(context.find_type_by_name("Huia.Native.Atom").is_some()); + assert!(context.find_type_by_name("Huia.Native.Integer").is_some()); } #[test] @@ -831,7 +831,7 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_infix()); - assert!(context.find_type("Huia.Native.Integer").is_some()); + assert!(context.find_type_by_name("Huia.Native.Integer").is_some()); } #[test] @@ -843,8 +843,8 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_constructor()); - assert!(context.find_type("Delorean").is_some()); - assert!(context.find_type("Huia.Native.Integer").is_some()); + assert!(context.find_type_by_name("Delorean").is_some()); + assert!(context.find_type_by_name("Huia.Native.Integer").is_some()); } #[test] @@ -856,7 +856,7 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_unary()); - assert!(context.find_type("Huia.Native.Integer").is_some()); + assert!(context.find_type_by_name("Huia.Native.Integer").is_some()); } #[test] @@ -870,7 +870,7 @@ mod test { } let call = builder.pop_ir().unwrap(); assert!(call.is_call()); - assert!(context.find_type("Huia.Native.Integer").is_some()); + assert!(context.find_type_by_name("Huia.Native.Integer").is_some()); } #[test] @@ -882,7 +882,7 @@ mod test { builder.build(term, &mut context); let ir = builder.pop_ir().unwrap(); assert!(ir.is_set_local()); - assert!(context.find_type("Huia.Native.Integer").is_some()); + assert!(context.find_type_by_name("Huia.Native.Integer").is_some()); } #[test] @@ -892,8 +892,8 @@ mod test { let mut context = Context::test(); builder.push_block(context.unknown_type(Loc::test())); builder.build(term, &mut context); - assert!(context.find_type("Delorean").is_some()); - assert!(context.find_type("Integer").is_some()); + assert!(context.find_type_by_name("Delorean").is_some()); + assert!(context.find_type_by_name("Integer").is_some()); } #[test] @@ -903,7 +903,7 @@ mod test { let mut context = Context::test(); builder.push_block(context.unknown_type(Loc::test())); builder.build(term, &mut context); - assert!(context.find_type("TimeMachine").is_some()); + assert!(context.find_type_by_name("TimeMachine").is_some()); } #[test] @@ -921,8 +921,8 @@ mod test { let mut context = Context::test(); builder.push_block(context.unknown_type(Loc::test())); builder.build(term, &mut context); - assert!(context.find_type("Delorean").is_some()); - assert!(context.find_type("TimeMachine").is_some()); + assert!(context.find_type_by_name("Delorean").is_some()); + assert!(context.find_type_by_name("TimeMachine").is_some()); } #[test] diff --git a/huia-compiler/src/stable.rs b/huia-compiler/src/stable.rs index af49c42..83aae59 100644 --- a/huia-compiler/src/stable.rs +++ b/huia-compiler/src/stable.rs @@ -4,10 +4,18 @@ use string_interner::{StringInterner, Sym}; pub struct StringTable(StringInterner); impl StringTable { - pub fn get(&self, idx: StringIdx) -> Option<&str> { + pub fn resolve(&self, idx: StringIdx) -> Option<&str> { self.0.resolve(idx.into()) } + #[allow(dead_code)] + pub fn get + AsRef>(&self, value: T) -> Option { + match self.0.get(value) { + Some(syn) => Some(syn.into()), + None => None, + } + } + pub fn intern + AsRef>(&mut self, value: T) -> StringIdx { self.0.get_or_intern(value).into() } @@ -60,9 +68,9 @@ mod test { } #[test] - fn test_get() { + fn test_resolve() { let mut stable = StringTable::default(); let idx = stable.intern("Marty McFly"); - assert_eq!(stable.get(idx).unwrap(), "Marty McFly"); + assert_eq!(stable.resolve(idx).unwrap(), "Marty McFly"); } }