Tidy up type-finding code.

This commit is contained in:
James Harton 2019-04-09 09:47:06 +12:00
parent 1d3815e281
commit 1aeeda93fe
3 changed files with 66 additions and 43 deletions

View file

@ -33,13 +33,12 @@ impl Context {
} }
// If there's already an anonymous trait with the same dependencies then return a reference to that instead. // If there's already an anonymous trait with the same dependencies then return a reference to that instead.
for (i, ty) in self.types.iter().enumerate() { if let Some(idx) = self.find_type_index(|ty| {
if ty.is_trait() ty.is_trait()
&& ty.is_anonymous() && ty.is_anonymous()
&& ty.get_dependencies() == Some(requirements.iter().collect()) && ty.get_dependencies() == Some(requirements.iter().collect())
{ }) {
return i.into(); return idx;
}
} }
let idx = self.types.len(); let idx = self.types.len();
@ -58,7 +57,7 @@ impl Context {
} }
pub fn current_file(&self) -> &str { pub fn current_file(&self) -> &str {
self.strings.get(self.current_file).unwrap() self.strings.resolve(self.current_file).unwrap()
} }
/// Define a function in the context and return a reference to it. /// Define a function in the context and return a reference to it.
@ -90,11 +89,7 @@ impl Context {
requirements: Vec<TyIdx>, requirements: Vec<TyIdx>,
location: Location, location: Location,
) -> TyIdx { ) -> TyIdx {
if self if self.find_type_index(|ty| ty.name() == Some(name)).is_some() {
.types
.iter()
.any(|t| t.name() == Some(name) && !t.is_unresolved())
{
let message = format!("Type {} redefined", self.get_string(name).unwrap()); let message = format!("Type {} redefined", self.get_string(name).unwrap());
self.compile_error(&message, location.clone(), ErrorKind::TypeRedefined); self.compile_error(&message, location.clone(), ErrorKind::TypeRedefined);
} }
@ -113,9 +108,8 @@ impl Context {
location: Location, location: Location,
) -> TyIdx { ) -> TyIdx {
if self if self
.types .find_type_index(|ty| ty.name() == Some(name) && !ty.is_unresolved())
.iter() .is_some()
.any(|t| t.name() == Some(name) && !t.is_unresolved())
{ {
let message = format!("Type {} redefined", self.get_string(name).unwrap()); let message = format!("Type {} redefined", self.get_string(name).unwrap());
self.compile_error(&message, location.clone(), ErrorKind::TypeRedefined) self.compile_error(&message, location.clone(), ErrorKind::TypeRedefined)
@ -146,9 +140,30 @@ impl Context {
} }
#[cfg(test)] #[cfg(test)]
pub fn find_type(&mut self, name: &str) -> Option<&Ty> { pub fn find_type_by_name(&self, name: &str) -> Option<&Ty> {
let name = self.constant_string(name); match self.strings.get(name) {
self.types.iter().find(|ty| ty.name() == Some(name)) Some(idx) => self.find_type(|ty| ty.name() == Some(idx)),
None => None,
}
}
/// Return the first type which matches the predicate.
pub fn find_type<F: Fn(&Ty) -> bool>(&self, predicate: F) -> Option<&Ty> {
match self.find_type_index(predicate) {
Some(idx) => self.types.get(usize::from(idx)),
None => None,
}
}
/// Return the index of the first type that matches the predicate
pub fn find_type_index<F: Fn(&Ty) -> bool>(&self, predicate: F) -> Option<TyIdx> {
for (idx, ty) in self.types.iter().enumerate() {
if predicate(ty) {
return Some(idx.into());
}
}
None
} }
/// Retrieve a specific block by it's index. /// Retrieve a specific block by it's index.
@ -174,7 +189,7 @@ impl Context {
} }
pub fn get_string(&self, idx: StringIdx) -> Option<&str> { pub fn get_string(&self, idx: StringIdx) -> Option<&str> {
self.strings.get(idx) self.strings.resolve(idx)
} }
/// Retrieve a specific type by it's index. /// Retrieve a specific type by it's index.
@ -245,7 +260,7 @@ impl Context {
/// Convert an AST `InputLocation` into a compiler `Location`. /// Convert an AST `InputLocation` into a compiler `Location`.
pub fn location(&self, location: &InputLocation) -> Location { pub fn location(&self, location: &InputLocation) -> Location {
let path = self.strings.get(self.current_file).unwrap(); let path = self.strings.resolve(self.current_file).unwrap();
Location::new(location.clone(), path) Location::new(location.clone(), path)
} }

View file

@ -709,7 +709,7 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_constant()); assert!(ir.is_constant());
assert!(context.find_type("Huia.Native.Atom").is_some()); assert!(context.find_type_by_name("Huia.Native.Atom").is_some());
} }
#[test] #[test]
@ -721,7 +721,7 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_constant()); assert!(ir.is_constant());
assert!(context.find_type("Huia.Native.Boolean").is_some()); assert!(context.find_type_by_name("Huia.Native.Boolean").is_some());
} }
#[test] #[test]
@ -733,7 +733,7 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_constant()); assert!(ir.is_constant());
assert!(context.find_type("Huia.Native.Float").is_some()); assert!(context.find_type_by_name("Huia.Native.Float").is_some());
} }
#[test] #[test]
@ -745,7 +745,7 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_constant()); assert!(ir.is_constant());
assert!(context.find_type("Huia.Native.String").is_some()); assert!(context.find_type_by_name("Huia.Native.String").is_some());
} }
#[test] #[test]
@ -757,7 +757,7 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_type_reference()); assert!(ir.is_type_reference());
assert!(context.find_type("MartyMcFly").is_some()); assert!(context.find_type_by_name("MartyMcFly").is_some());
} }
#[test] #[test]
@ -769,8 +769,8 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_constant()); assert!(ir.is_constant());
assert!(context.find_type("Huia.Native.Array").is_some()); assert!(context.find_type_by_name("Huia.Native.Array").is_some());
assert!(context.find_type("Huia.Native.Integer").is_some()); assert!(context.find_type_by_name("Huia.Native.Integer").is_some());
} }
#[test] #[test]
@ -782,9 +782,9 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_constant()); assert!(ir.is_constant());
assert!(context.find_type("Huia.Native.Map").is_some()); assert!(context.find_type_by_name("Huia.Native.Map").is_some());
assert!(context.find_type("Huia.Native.Atom").is_some()); assert!(context.find_type_by_name("Huia.Native.Atom").is_some());
assert!(context.find_type("Huia.Native.Integer").is_some()); assert!(context.find_type_by_name("Huia.Native.Integer").is_some());
} }
#[test] #[test]
@ -831,7 +831,7 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_infix()); assert!(ir.is_infix());
assert!(context.find_type("Huia.Native.Integer").is_some()); assert!(context.find_type_by_name("Huia.Native.Integer").is_some());
} }
#[test] #[test]
@ -843,8 +843,8 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_constructor()); assert!(ir.is_constructor());
assert!(context.find_type("Delorean").is_some()); assert!(context.find_type_by_name("Delorean").is_some());
assert!(context.find_type("Huia.Native.Integer").is_some()); assert!(context.find_type_by_name("Huia.Native.Integer").is_some());
} }
#[test] #[test]
@ -856,7 +856,7 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_unary()); assert!(ir.is_unary());
assert!(context.find_type("Huia.Native.Integer").is_some()); assert!(context.find_type_by_name("Huia.Native.Integer").is_some());
} }
#[test] #[test]
@ -870,7 +870,7 @@ mod test {
} }
let call = builder.pop_ir().unwrap(); let call = builder.pop_ir().unwrap();
assert!(call.is_call()); assert!(call.is_call());
assert!(context.find_type("Huia.Native.Integer").is_some()); assert!(context.find_type_by_name("Huia.Native.Integer").is_some());
} }
#[test] #[test]
@ -882,7 +882,7 @@ mod test {
builder.build(term, &mut context); builder.build(term, &mut context);
let ir = builder.pop_ir().unwrap(); let ir = builder.pop_ir().unwrap();
assert!(ir.is_set_local()); assert!(ir.is_set_local());
assert!(context.find_type("Huia.Native.Integer").is_some()); assert!(context.find_type_by_name("Huia.Native.Integer").is_some());
} }
#[test] #[test]
@ -892,8 +892,8 @@ mod test {
let mut context = Context::test(); let mut context = Context::test();
builder.push_block(context.unknown_type(Loc::test())); builder.push_block(context.unknown_type(Loc::test()));
builder.build(term, &mut context); builder.build(term, &mut context);
assert!(context.find_type("Delorean").is_some()); assert!(context.find_type_by_name("Delorean").is_some());
assert!(context.find_type("Integer").is_some()); assert!(context.find_type_by_name("Integer").is_some());
} }
#[test] #[test]
@ -903,7 +903,7 @@ mod test {
let mut context = Context::test(); let mut context = Context::test();
builder.push_block(context.unknown_type(Loc::test())); builder.push_block(context.unknown_type(Loc::test()));
builder.build(term, &mut context); builder.build(term, &mut context);
assert!(context.find_type("TimeMachine").is_some()); assert!(context.find_type_by_name("TimeMachine").is_some());
} }
#[test] #[test]
@ -921,8 +921,8 @@ mod test {
let mut context = Context::test(); let mut context = Context::test();
builder.push_block(context.unknown_type(Loc::test())); builder.push_block(context.unknown_type(Loc::test()));
builder.build(term, &mut context); builder.build(term, &mut context);
assert!(context.find_type("Delorean").is_some()); assert!(context.find_type_by_name("Delorean").is_some());
assert!(context.find_type("TimeMachine").is_some()); assert!(context.find_type_by_name("TimeMachine").is_some());
} }
#[test] #[test]

View file

@ -4,10 +4,18 @@ use string_interner::{StringInterner, Sym};
pub struct StringTable(StringInterner<Sym>); pub struct StringTable(StringInterner<Sym>);
impl StringTable { impl StringTable {
pub fn get(&self, idx: StringIdx) -> Option<&str> { pub fn resolve(&self, idx: StringIdx) -> Option<&str> {
self.0.resolve(idx.into()) self.0.resolve(idx.into())
} }
#[allow(dead_code)]
pub fn get<T: Into<String> + AsRef<str>>(&self, value: T) -> Option<StringIdx> {
match self.0.get(value) {
Some(syn) => Some(syn.into()),
None => None,
}
}
pub fn intern<T: Into<String> + AsRef<str>>(&mut self, value: T) -> StringIdx { pub fn intern<T: Into<String> + AsRef<str>>(&mut self, value: T) -> StringIdx {
self.0.get_or_intern(value).into() self.0.get_or_intern(value).into()
} }
@ -60,9 +68,9 @@ mod test {
} }
#[test] #[test]
fn test_get() { fn test_resolve() {
let mut stable = StringTable::default(); let mut stable = StringTable::default();
let idx = stable.intern("Marty McFly"); let idx = stable.intern("Marty McFly");
assert_eq!(stable.get(idx).unwrap(), "Marty McFly"); assert_eq!(stable.resolve(idx).unwrap(), "Marty McFly");
} }
} }