diff --git a/pkg/lang/set.go b/pkg/lang/set.go index d468896..f3d363c 100644 --- a/pkg/lang/set.go +++ b/pkg/lang/set.go @@ -9,22 +9,22 @@ type Set struct { meta IPersistentMap hash, hasheq uint32 - vals []interface{} + hashMap IPersistentMap } -type PersistentHashSet = Set // hack until we have a proper persistent hash set +type PersistentHashSet = Set -func CreatePersistentTreeSet(keys ISeq) interface{} { +func CreatePersistentTreeSet(keys ISeq) any { // TODO: implement return NewSet(seqToSlice(keys)...) } -func CreatePersistentTreeSetWithComparator(comparator IFn, keys ISeq) interface{} { +func CreatePersistentTreeSetWithComparator(comparator IFn, keys ISeq) any { // TODO: implement return NewSet(seqToSlice(keys)...) } -func NewSet(vals ...interface{}) *Set { +func NewSet(vals ...any) *Set { set, err := NewSet2(vals...) if err != nil { panic(err) @@ -32,19 +32,16 @@ func NewSet(vals ...interface{}) *Set { return set } -func NewSet2(vals ...interface{}) (*Set, error) { - // check for duplicates +func NewSet2(vals ...any) (*Set, error) { + set := &Set{ + hashMap: NewPersistentHashMap(), + } for i := 0; i < len(vals); i++ { - for j := i + 1; j < len(vals); j++ { - if Equiv(vals[i], vals[j]) { - return nil, NewIllegalArgumentError(fmt.Sprintf("duplicate key: %v", vals[i])) - } - } + val := vals[i] + set.hashMap = set.hashMap.Assoc(val, true).(IPersistentMap) } - return &Set{ - vals: vals, - }, nil + return set, nil } var ( @@ -55,16 +52,15 @@ var ( emptySet = NewSet() ) -func (s *Set) Get(key interface{}) interface{} { - for _, v := range s.vals { - if Equiv(v, key) { - return v - } +func (s *Set) Get(key any) any { + val := s.hashMap.ValAt(key) + if val == true { + return key } return nil } -func (s *Set) Invoke(args ...interface{}) interface{} { +func (s *Set) Invoke(args ...any) any { if len(args) != 1 { panic(fmt.Errorf("set apply expects 1 argument, got %d", len(args))) } @@ -72,40 +68,36 @@ func (s *Set) Invoke(args ...interface{}) interface{} { return s.Get(args[0]) } -func (s *Set) ApplyTo(args ISeq) interface{} { +func (s *Set) ApplyTo(args ISeq) any { return s.Invoke(seqToSlice(args)...) } -func (s *Set) Cons(v interface{}) Conser { +func (s *Set) Cons(v any) Conser { if s.Contains(v) { return s } - return NewSet(append(s.vals, v)...) + return &Set{ + meta: s.meta, + hashMap: s.hashMap.Assoc(v, true).(IPersistentMap), + } } -func (s *Set) Disjoin(v interface{}) IPersistentSet { - for i, val := range s.vals { - if Equiv(val, v) { - newItems := make([]interface{}, len(s.vals)-1) - copy(newItems, s.vals[:i]) - copy(newItems[i:], s.vals[i+1:]) - return NewSet(newItems...) - } +func (s *Set) Disjoin(v any) IPersistentSet { + if !s.Contains(v) { + return s + } + return &Set{ + meta: s.meta, + hashMap: s.hashMap.Without(v).(IPersistentMap), } - return s } -func (s *Set) Contains(v interface{}) bool { - for _, val := range s.vals { - if Equiv(val, v) { - return true - } - } - return false +func (s *Set) Contains(v any) bool { + return s.hashMap.ContainsKey(v) } func (s *Set) Count() int { - return len(s.vals) + return s.hashMap.Count() } func (s *Set) xxx_counted() {} @@ -122,7 +114,7 @@ func (s *Set) String() string { return PrintString(s) } -func (s *Set) Equals(v2 interface{}) bool { +func (s *Set) Equals(v2 any) bool { if s == v2 { return true } @@ -143,10 +135,10 @@ func (s *Set) Equals(v2 interface{}) bool { } func (s *Set) Seq() ISeq { - if s.Count() == 0 { + if s.hashMap.Count() == 0 { return nil } - return NewSliceSeq(s.vals) + return NewMapKeySeq(Seq(s.hashMap)) } func (s *Set) Equiv(o any) bool { @@ -165,15 +157,14 @@ func (s *Set) Meta() IPersistentMap { return s.meta } -func (s *Set) WithMeta(meta IPersistentMap) interface{} { +func (s *Set) WithMeta(meta IPersistentMap) any { if meta == s.meta { return s } - return &Set{ - meta: meta, - vals: s.vals, - } + cpy := *s + cpy.meta = meta + return &cpy } func (s *Set) AsTransient() ITransientCollection { @@ -185,7 +176,7 @@ type TransientSet struct { *Set } -func (s *TransientSet) Conj(v interface{}) Conjer { +func (s *TransientSet) Conj(v any) Conjer { return &TransientSet{Set: s.Set.Cons(v).(*Set)} } diff --git a/pkg/reader/testdata/reader/set00.glj b/pkg/reader/testdata/reader/set00.glj index a6b83a0..327b712 100644 --- a/pkg/reader/testdata/reader/set00.glj +++ b/pkg/reader/testdata/reader/set00.glj @@ -1 +1 @@ -#{:a :b :c} +#{:a} diff --git a/pkg/reader/testdata/reader/set00.out b/pkg/reader/testdata/reader/set00.out index a6b83a0..327b712 100644 --- a/pkg/reader/testdata/reader/set00.out +++ b/pkg/reader/testdata/reader/set00.out @@ -1 +1 @@ -#{:a :b :c} +#{:a}