From b81b3abf42627703c2896275696a77da0bf831d6 Mon Sep 17 00:00:00 2001 From: Ruben Menke Date: Tue, 15 Aug 2023 08:59:12 +0200 Subject: [PATCH 1/6] made diskcache work, add still too slow --- dev/company_names.py | 63 +++++++++++--------------- pyproject.toml | 12 +++-- simstring/database/disk.py | 88 +++++++++++++++++++++++++++++++++++++ tests/database/test_disk.py | 63 ++++++++++++++++++++++++++ 4 files changed, 185 insertions(+), 41 deletions(-) create mode 100644 simstring/database/disk.py create mode 100644 tests/database/test_disk.py diff --git a/dev/company_names.py b/dev/company_names.py index 06a47b7..c3d5394 100644 --- a/dev/company_names.py +++ b/dev/company_names.py @@ -1,56 +1,45 @@ # coding: utf-8 - -import os, sys - -import numpy as np - from simstring.feature_extractor.character_ngram import CharacterNgramFeatureExtractor -from simstring.measure.cosine import ( - CosineMeasure, -) # , OverlapMeasure, LeftOverlapMeasure +from simstring.measure.cosine import CosineMeasure +from simstring.measure.overlap import OverlapMeasure, LeftOverlapMeasure -# from simstring.database.mongo import MongoDatabase from simstring.database.dict import DictDatabase +from simstring.database.disk import DiskDatabase from simstring.searcher import Searcher +from tqdm import tqdm from pyinstrument import Profiler profiler = Profiler() -def output_similar_strings_of_each_line(path, measure): +def output_similar_strings_of_each_line(path, measures, db_cls): strings = [] with open(path, "r") as lines: for line in lines: - strings.append(line.rstrip("\r\n")) - - db = DictDatabase(CharacterNgramFeatureExtractor(2)) - for string in strings: - db.add(string) + strings.append(line.rstrip("\r\n").strip().lower()) + + db = make_db(db_cls, strings) - # db.save("companies.db") + for measure in measures: + searcher = Searcher(db, measure) + profiler.start() - # dbl = DictDatabase.load("companies.db") + for string in strings: + result = searcher.search(string, 0.8) - searcher = Searcher(db, measure) - profiler.start() + profiler.stop() + print(result) + print(db_cls.__name__, measure.__class__.__name__) + profiler.print() - for string in strings: - result = searcher.search(string, 0.8) - # result = [str(np.round(x[0], 5)) + ' ' + x[1] for x in searcher.ranked_search(string, 0.8)] - # print("\t".join([string, ",".join(result)])) - - profiler.stop() - print(result) - profiler.print() - # profiler.open_in_browser() - - -measure = CosineMeasure() -output_similar_strings_of_each_line("dev/data/company_names.txt", measure) - -# measure = OverlapMeasure() -# output_similar_strings_of_each_line("dev/data/company_names.txt", measure) +def make_db(db_cls, strings): + db = db_cls(CharacterNgramFeatureExtractor(2)) + for string in tqdm(strings): + db.add(string) + return db -# measure = LeftOverlapMeasure() -# output_similar_strings_of_each_line("./data/company_names.txt", measure) +if __name__ =="__main__": + measures = [CosineMeasure(), OverlapMeasure(), LeftOverlapMeasure()] + for db_cls in [DictDatabase,DiskDatabase]: + output_similar_strings_of_each_line("dev/data/company_names.txt", measures, db_cls) diff --git a/pyproject.toml b/pyproject.toml index 2df8d80..78ab317 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", ] -dependencies = [] +dependencies = ["diskcache"] dynamic = ["version"] [project.urls] @@ -45,13 +45,17 @@ mypy-args = [ "--check-untyped-defs", "--install-types" ] +exclude = [ + "simstring/database/disk.py", + "simstring/database/base.py", +] [tool.hatch.envs.test] dependencies = [ "pytest", "pytest-cov", "build", - "cython" + "cython==3.0.0" ] [tool.hatch.envs.default.scripts] @@ -61,7 +65,7 @@ no-cov = "cov --no-cov {args}" build = "python -m build" [[tool.hatch.envs.test.matrix]] -python = [ "38", "39", "310", "311"] +python = ["38","39", "310", "311"] [tool.coverage.run] branch = true @@ -94,7 +98,7 @@ serve = "cd simstring && mkdocs serve --dev-addr localhost:8000" [tool.hatch.envs.benchmark] dependencies = [ - "pyinstrument", "benchmarker" , "numpy" + "pyinstrument", "benchmarker" , "numpy", "tqdm" ] [[tool.hatch.envs.benchmark.matrix]] python = [ "38", "39", "310", "311"] diff --git a/simstring/database/disk.py b/simstring/database/disk.py new file mode 100644 index 0000000..7aeb099 --- /dev/null +++ b/simstring/database/disk.py @@ -0,0 +1,88 @@ + +from typing import List, Set, Dict, Union +from .base import BaseDatabase +from simstring.feature_extractor.character_ngram import CharacterNgramFeatureExtractor +from simstring.feature_extractor.word_ngram import WordNgramFeatureExtractor + +from io import BufferedWriter +import diskcache as dc +from functools import lru_cache + +import os + +FeatureExtractor = Union[ + CharacterNgramFeatureExtractor, WordNgramFeatureExtractor + ] + +class DiskDatabase(BaseDatabase): + def __init__( + self, + feature_extractor: FeatureExtractor, + path:str= 'tmp' + ): + self.feature_extractor = feature_extractor + self.feature_set_size_to_string_map: dc.Cache = dc.Cache(os.path.join(path,'feature_set_size_to_string_map')) + self.feature_set_size_and_feature_to_string_map: dc.Cache = dc.Cache(os.path.join(path,'feature_set_size_and_feature_to_string_map')) + self._min_feature_size = 9999999 + self._max_feature_size = 0 + self.path = path + + @staticmethod + def _make_key(size: int, feature: str) -> str: + return f"{size}-{feature}" + + def add_feature_set_size_and_feature_to_string_map(self, size, feature, string)-> None: + key = self._make_key(size,feature) + if key in self.feature_set_size_and_feature_to_string_map: + d = self.feature_set_size_and_feature_to_string_map[key] + if string in d: + return + else: + d = set() + d.add(string) + self.feature_set_size_and_feature_to_string_map[key] = d + + def get_feature_set_size_and_feature_to_string_map(self, size: int, feature: str + ) -> Set[str]: + try: + return self.feature_set_size_and_feature_to_string_map[self._make_key(size,feature)] + except KeyError: + return set() + + def add(self, string: str) -> None: + features = self.feature_extractor.features(string) + size = len(features) + + if size not in self.feature_set_size_to_string_map: + self.feature_set_size_to_string_map[size] = set() + else: + size_to_string_map = self.feature_set_size_to_string_map[size] + size_to_string_map.add(string) + self.feature_set_size_to_string_map[size] = size_to_string_map + + + self._min_feature_size = min(self._min_feature_size, size) + self._max_feature_size = max(self._max_feature_size, size) + + for feature in features: + self.add_feature_set_size_and_feature_to_string_map(size, feature, string) + + def all(self) -> List[str]: + strings = [] + for k in self.feature_set_size_to_string_map.iterkeys(): + strings.extend(self.feature_set_size_to_string_map[k]) + return strings + + def lookup_strings_by_feature_set_size_and_feature( + self, size: int, feature: str + ) -> Set[str]: + return self.get_feature_set_size_and_feature_to_string_map(size,feature) + + def min_feature_size(self) -> int: + return self._min_feature_size + + def max_feature_size(self) -> int: + return self._max_feature_size + + + diff --git a/tests/database/test_disk.py b/tests/database/test_disk.py new file mode 100644 index 0000000..c42a736 --- /dev/null +++ b/tests/database/test_disk.py @@ -0,0 +1,63 @@ +# -*- coding:utf-8 -*- + +from unittest import TestCase +from simstring.database.disk import DiskDatabase +from simstring.feature_extractor.character_ngram import CharacterNgramFeatureExtractor +import pickle +import os + + +class TestDict(TestCase): + strings = ["a", "ab", "abc", "abcd", "abcde"] + + def setUp(self): + self.db = DiskDatabase(CharacterNgramFeatureExtractor(2)) + for string in self.strings: + self.db.add(string) + + def test_strings(self): + self.assertEqual(self.db.strings, self.strings) + + # def test_min_feature_size(self): + # self.assertEqual(self.db.min_feature_size(), min(map(lambda x: len(x) + 1, self.strings))) + + # def test_max_feature_size(self): + # self.assertEqual(self.db.max_feature_size(), max(map(lambda x: len(x) + 1, self.strings))) + + def test_lookup_strings_by_feature_set_size_and_feature(self): + self.assertEqual( + self.db.lookup_strings_by_feature_set_size_and_feature(4, "ab_1"), + set(["abc"]), + ) + self.assertEqual( + self.db.lookup_strings_by_feature_set_size_and_feature(3, "ab_1"), + set(["ab"]), + ) + self.assertEqual( + self.db.lookup_strings_by_feature_set_size_and_feature(2, "ab_1"), set([]) + ) + + def test_load_from_folder(self): + + with open("test.pkl", "wb") as f: + pickle.dump(self.db, f) + + + with open("test.pkl", "rb") as f: + new = pickle.load(f) + + self.assertEqual(self.db._min_feature_size, new._min_feature_size) + self.assertEqual(self.db._max_feature_size, new._max_feature_size) + self.assertEqual( + self.db.feature_extractor.__class__, new.feature_extractor.__class__ + ) + self.assertEqual(self.db.feature_extractor.n, new.feature_extractor.n) + self.assertEqual( + set(self.db.feature_set_size_to_string_map.iterkeys()), set(new.feature_set_size_to_string_map.iterkeys()) + ) + self.assertEqual( + set(self.db.feature_set_size_and_feature_to_string_map.iterkeys()), + set(new.feature_set_size_and_feature_to_string_map.iterkeys()), + ) + + os.remove("test.pkl") From e617dc5c760daa2443f62882facc7b8570ef5b4c Mon Sep 17 00:00:00 2001 From: Ruben Menke Date: Tue, 15 Aug 2023 11:19:19 +0200 Subject: [PATCH 2/6] wip, faster add --- simstring/database/disk.py | 19 ++++++++++++++++--- tests/database/test_dict.py | 2 +- tests/database/test_disk.py | 2 +- 3 files changed, 18 insertions(+), 5 deletions(-) diff --git a/simstring/database/disk.py b/simstring/database/disk.py index 7aeb099..4920a73 100644 --- a/simstring/database/disk.py +++ b/simstring/database/disk.py @@ -49,7 +49,22 @@ def get_feature_set_size_and_feature_to_string_map(self, size: int, feature: str except KeyError: return set() + def commit(self): + pass + def add(self, string: str) -> None: + features, size = self._process_string(string) + + for feature in features: + self.add_feature_set_size_and_feature_to_string_map(size, feature, string) + + def fast_add(self, string: str) -> None: + features, size = self._process_string(string) + + for feature in features: + self.add_feature_set_size_and_feature_to_string_map(size, feature, string) + + def _process_string(self, string:str): features = self.feature_extractor.features(string) size = len(features) @@ -63,9 +78,7 @@ def add(self, string: str) -> None: self._min_feature_size = min(self._min_feature_size, size) self._max_feature_size = max(self._max_feature_size, size) - - for feature in features: - self.add_feature_set_size_and_feature_to_string_map(size, feature, string) + return features,size def all(self) -> List[str]: strings = [] diff --git a/tests/database/test_dict.py b/tests/database/test_dict.py index 7c85234..d8d5240 100644 --- a/tests/database/test_dict.py +++ b/tests/database/test_dict.py @@ -16,7 +16,7 @@ def setUp(self): self.db.add(string) def test_strings(self): - self.assertEqual(self.db.strings, self.strings) + self.assertEqual(self.all(), self.strings) # def test_min_feature_size(self): # self.assertEqual(self.db.min_feature_size(), min(map(lambda x: len(x) + 1, self.strings))) diff --git a/tests/database/test_disk.py b/tests/database/test_disk.py index c42a736..b38d0af 100644 --- a/tests/database/test_disk.py +++ b/tests/database/test_disk.py @@ -16,7 +16,7 @@ def setUp(self): self.db.add(string) def test_strings(self): - self.assertEqual(self.db.strings, self.strings) + self.assertEqual(sorted(self.db.all()), sorted(self.strings)) # def test_min_feature_size(self): # self.assertEqual(self.db.min_feature_size(), min(map(lambda x: len(x) + 1, self.strings))) From 09dba7cf8dc2317517addc065a004c77f89ded69 Mon Sep 17 00:00:00 2001 From: Ruben Menke Date: Tue, 15 Aug 2023 13:45:26 +0200 Subject: [PATCH 3/6] improved benchmarking --- .gitignore | 4 +++- dev/company_names.py | 13 ++++++++++--- tests/database/test_dict.py | 2 +- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/.gitignore b/.gitignore index cf56c1a..e42741d 100644 --- a/.gitignore +++ b/.gitignore @@ -12,4 +12,6 @@ dev/data/geo_address.csv dev/geo_matching.py *.so .coverage -simstring/site/ \ No newline at end of file +simstring/site/ +tmp*/ +addresses.csv \ No newline at end of file diff --git a/dev/company_names.py b/dev/company_names.py index c3d5394..f63542f 100644 --- a/dev/company_names.py +++ b/dev/company_names.py @@ -19,7 +19,7 @@ def output_similar_strings_of_each_line(path, measures, db_cls): for line in lines: strings.append(line.rstrip("\r\n").strip().lower()) - db = make_db(db_cls, strings) + db = make_db(db_cls, strings[:10_000]) for measure in measures: searcher = Searcher(db, measure) @@ -40,6 +40,13 @@ def make_db(db_cls, strings): return db if __name__ =="__main__": - measures = [CosineMeasure(), OverlapMeasure(), LeftOverlapMeasure()] + file = "dev/data/company_names.txt" + # file = "dev/data/unabridged_dictionary.txt" + # file = "dev/data/addresses.csv" + # measures = [CosineMeasure(), OverlapMeasure(), LeftOverlapMeasure()] + measures = [CosineMeasure()] for db_cls in [DictDatabase,DiskDatabase]: - output_similar_strings_of_each_line("dev/data/company_names.txt", measures, db_cls) + output_similar_strings_of_each_line(file, measures, db_cls) + + # for db_cls in [DictDatabase,DiskDatabase]: + # output_similar_strings_of_each_line("dev/data/unabridged_dictionary2.txt", measures, db_cls) diff --git a/tests/database/test_dict.py b/tests/database/test_dict.py index d8d5240..0109682 100644 --- a/tests/database/test_dict.py +++ b/tests/database/test_dict.py @@ -16,7 +16,7 @@ def setUp(self): self.db.add(string) def test_strings(self): - self.assertEqual(self.all(), self.strings) + self.assertEqual(sorted(self.db.all()), sorted(self.strings)) # def test_min_feature_size(self): # self.assertEqual(self.db.min_feature_size(), min(map(lambda x: len(x) + 1, self.strings))) From 83982be15a1cc0a951228f7f15e5897974d1096b Mon Sep 17 00:00:00 2001 From: Ruben Menke Date: Wed, 16 Aug 2023 08:31:44 +0200 Subject: [PATCH 4/6] clean up testing --- pyproject.toml | 2 +- tests/database/test_dict.py | 9 +++++---- tests/database/test_disk.py | 16 +++++++++++----- 3 files changed, 17 insertions(+), 10 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 78ab317..7bbce9c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", ] -dependencies = ["diskcache"] +dependencies = ["diskcache==5.6.1"] dynamic = ["version"] [project.urls] diff --git a/tests/database/test_dict.py b/tests/database/test_dict.py index 0109682..4c94bb5 100644 --- a/tests/database/test_dict.py +++ b/tests/database/test_dict.py @@ -18,11 +18,12 @@ def setUp(self): def test_strings(self): self.assertEqual(sorted(self.db.all()), sorted(self.strings)) - # def test_min_feature_size(self): - # self.assertEqual(self.db.min_feature_size(), min(map(lambda x: len(x) + 1, self.strings))) - # def test_max_feature_size(self): - # self.assertEqual(self.db.max_feature_size(), max(map(lambda x: len(x) + 1, self.strings))) + def test_min_feature_size(self): + self.assertEqual(self.db.min_feature_size(), 2) + + def test_max_feature_size(self): + self.assertEqual(self.db.max_feature_size(), 6) def test_lookup_strings_by_feature_set_size_and_feature(self): self.assertEqual( diff --git a/tests/database/test_disk.py b/tests/database/test_disk.py index b38d0af..849cc38 100644 --- a/tests/database/test_disk.py +++ b/tests/database/test_disk.py @@ -5,24 +5,30 @@ from simstring.feature_extractor.character_ngram import CharacterNgramFeatureExtractor import pickle import os +import shutil class TestDict(TestCase): strings = ["a", "ab", "abc", "abcd", "abcde"] def setUp(self): - self.db = DiskDatabase(CharacterNgramFeatureExtractor(2)) + self.db = DiskDatabase(CharacterNgramFeatureExtractor(2), path="tmp_db_for_tests") for string in self.strings: self.db.add(string) + + def tearDown(self) -> None: + shutil.rmtree(self.db.path) + return super().tearDown() + def test_strings(self): self.assertEqual(sorted(self.db.all()), sorted(self.strings)) - # def test_min_feature_size(self): - # self.assertEqual(self.db.min_feature_size(), min(map(lambda x: len(x) + 1, self.strings))) + def test_min_feature_size(self): + self.assertEqual(self.db.min_feature_size(), 2) - # def test_max_feature_size(self): - # self.assertEqual(self.db.max_feature_size(), max(map(lambda x: len(x) + 1, self.strings))) + def test_max_feature_size(self): + self.assertEqual(self.db.max_feature_size(), 6) def test_lookup_strings_by_feature_set_size_and_feature(self): self.assertEqual( From 835ea689b502d4451e6901f73a9d2a829a1a0e4e Mon Sep 17 00:00:00 2001 From: Ruben Menke Date: Wed, 16 Aug 2023 08:36:34 +0200 Subject: [PATCH 5/6] fix tests --- simstring/database/disk.py | 7 ++++--- tests/database/test_disk.py | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/simstring/database/disk.py b/simstring/database/disk.py index 4920a73..f993348 100644 --- a/simstring/database/disk.py +++ b/simstring/database/disk.py @@ -69,11 +69,12 @@ def _process_string(self, string:str): size = len(features) if size not in self.feature_set_size_to_string_map: - self.feature_set_size_to_string_map[size] = set() + size_to_string_map = set() else: size_to_string_map = self.feature_set_size_to_string_map[size] - size_to_string_map.add(string) - self.feature_set_size_to_string_map[size] = size_to_string_map + + size_to_string_map.add(string) + self.feature_set_size_to_string_map[size] = size_to_string_map self._min_feature_size = min(self._min_feature_size, size) diff --git a/tests/database/test_disk.py b/tests/database/test_disk.py index 849cc38..c163093 100644 --- a/tests/database/test_disk.py +++ b/tests/database/test_disk.py @@ -8,7 +8,7 @@ import shutil -class TestDict(TestCase): +class TestDisk(TestCase): strings = ["a", "ab", "abc", "abcd", "abcde"] def setUp(self): From 8524ff0bc87c78a8364c5c87233f9d0709728938 Mon Sep 17 00:00:00 2001 From: Ruben Menke Date: Thu, 17 Aug 2023 12:25:26 +0200 Subject: [PATCH 6/6] added redis support --- dev/company_names.py | 15 +++++++-- pyproject.toml | 3 +- simstring/database/base.py | 3 ++ simstring/database/redis.py | 65 ++++++++++++++++++++++++++++++++++++ tests/database/test_redis.py | 63 ++++++++++++++++++++++++++++++++++ 5 files changed, 145 insertions(+), 4 deletions(-) create mode 100644 simstring/database/redis.py create mode 100644 tests/database/test_redis.py diff --git a/dev/company_names.py b/dev/company_names.py index f63542f..5f8691c 100644 --- a/dev/company_names.py +++ b/dev/company_names.py @@ -5,6 +5,7 @@ from simstring.database.dict import DictDatabase from simstring.database.disk import DiskDatabase +from simstring.database.redis import RedisDatabase from simstring.searcher import Searcher from tqdm import tqdm @@ -19,7 +20,7 @@ def output_similar_strings_of_each_line(path, measures, db_cls): for line in lines: strings.append(line.rstrip("\r\n").strip().lower()) - db = make_db(db_cls, strings[:10_000]) + db = make_db(db_cls, strings) for measure in measures: searcher = Searcher(db, measure) @@ -35,8 +36,14 @@ def output_similar_strings_of_each_line(path, measures, db_cls): def make_db(db_cls, strings): db = db_cls(CharacterNgramFeatureExtractor(2)) + i = 0 for string in tqdm(strings): db.add(string) + i += 1 + if (i % 10000) == 0: + db.commit() + i = 0 + db.commit() return db if __name__ =="__main__": @@ -45,8 +52,10 @@ def make_db(db_cls, strings): # file = "dev/data/addresses.csv" # measures = [CosineMeasure(), OverlapMeasure(), LeftOverlapMeasure()] measures = [CosineMeasure()] - for db_cls in [DictDatabase,DiskDatabase]: - output_similar_strings_of_each_line(file, measures, db_cls) + dbs = [DictDatabase,DiskDatabase, RedisDatabase] + # dbs = [DiskDatabase] + for db_cls in dbs: + output_similar_strings_of_each_line(file, measures, db_cls) # for db_cls in [DictDatabase,DiskDatabase]: # output_similar_strings_of_each_line("dev/data/unabridged_dictionary2.txt", measures, db_cls) diff --git a/pyproject.toml b/pyproject.toml index 7bbce9c..32b2df3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -21,7 +21,7 @@ classifiers = [ "Programming Language :: Python :: 3.11", "Programming Language :: Python :: Implementation :: CPython", ] -dependencies = ["diskcache==5.6.1"] +dependencies = ["diskcache==5.6.1", "fakeredis", "redis"] dynamic = ["version"] [project.urls] @@ -47,6 +47,7 @@ mypy-args = [ ] exclude = [ "simstring/database/disk.py", + "simstring/database/redis.py", "simstring/database/base.py", ] diff --git a/simstring/database/base.py b/simstring/database/base.py index 7210fe8..f8c16b1 100644 --- a/simstring/database/base.py +++ b/simstring/database/base.py @@ -13,3 +13,6 @@ def max_feature_size(self): def lookup_strings_by_feature_set_size_and_feature(self, size, feature): raise NotImplementedError + + def commit(self): + pass \ No newline at end of file diff --git a/simstring/database/redis.py b/simstring/database/redis.py new file mode 100644 index 0000000..416f2ba --- /dev/null +++ b/simstring/database/redis.py @@ -0,0 +1,65 @@ + +from typing import List, Set, Dict, Union +from .base import BaseDatabase +from simstring.feature_extractor.character_ngram import CharacterNgramFeatureExtractor +from simstring.feature_extractor.word_ngram import WordNgramFeatureExtractor + +from io import BufferedWriter +from redis import Redis +from fakeredis import FakeRedis +from functools import lru_cache + +import os + +FeatureExtractor = Union[ + CharacterNgramFeatureExtractor, WordNgramFeatureExtractor + ] + +class RedisDatabase(BaseDatabase): + def __init__( + self, + feature_extractor: FeatureExtractor, + redis_connection: Union[Redis,FakeRedis] = FakeRedis + + ): + self.feature_extractor = feature_extractor + self.feature_set_size_to_string_map = redis_connection(db=0, decode_responses=True) + self.feature_set_size_and_feature_to_string_map = redis_connection(db=1, decode_responses=True) + self._min_feature_size = 9999999 + self._max_feature_size = 0 + + + @staticmethod + def _make_key(size: int, feature: str) -> str: + return f"{size}-{feature}" + + def add(self, string: str) -> None: + features = self.feature_extractor.features(string) + size = len(features) + self.feature_set_size_to_string_map.sadd(size, string) + + self._min_feature_size = min(self._min_feature_size, size) + self._max_feature_size = max(self._max_feature_size, size) + + for feature in features: + self.feature_set_size_and_feature_to_string_map.sadd(self._make_key(size, feature), string) + + def all(self) -> List[str]: + strings = [] + for k in self.feature_set_size_to_string_map.keys(): + strings.extend(self.feature_set_size_to_string_map.smembers(k)) + return strings + + def lookup_strings_by_feature_set_size_and_feature( + self, size: int, feature: str + ) -> Set[str]: + return self.feature_set_size_and_feature_to_string_map.smembers(self._make_key(size, feature)) + + def min_feature_size(self) -> int: + return self._min_feature_size + + def max_feature_size(self) -> int: + return self._max_feature_size + + + diff --git a/tests/database/test_redis.py b/tests/database/test_redis.py new file mode 100644 index 0000000..7efeeca --- /dev/null +++ b/tests/database/test_redis.py @@ -0,0 +1,63 @@ +# -*- coding:utf-8 -*- + +from unittest import TestCase +from simstring.database.redis import RedisDatabase +from simstring.feature_extractor.character_ngram import CharacterNgramFeatureExtractor +import pickle +import os +import shutil +from fakeredis import FakeRedis + +class TestRedis(TestCase): + strings = ["a", "ab", "abc", "abcd", "abcde"] + + def setUp(self): + self.db = RedisDatabase(CharacterNgramFeatureExtractor(2), redis_connection=FakeRedis) + for string in self.strings: + self.db.add(string) + + def test_strings(self): + self.assertEqual(sorted(self.db.all()), sorted(self.strings)) + + def test_min_feature_size(self): + self.assertEqual(self.db.min_feature_size(), 2) + + def test_max_feature_size(self): + self.assertEqual(self.db.max_feature_size(), 6) + + def test_lookup_strings_by_feature_set_size_and_feature(self): + self.assertEqual( + self.db.lookup_strings_by_feature_set_size_and_feature(4, "ab_1"), + set(["abc"]), + ) + self.assertEqual( + self.db.lookup_strings_by_feature_set_size_and_feature(3, "ab_1"), + set(["ab"]), + ) + self.assertEqual( + self.db.lookup_strings_by_feature_set_size_and_feature(2, "ab_1"), set([]) + ) + + def test_load_from_folder(self): + with open("test.pkl", "wb") as f: + pickle.dump(self.db, f) + + + with open("test.pkl", "rb") as f: + new = pickle.load(f) + + self.assertEqual(self.db._min_feature_size, new._min_feature_size) + self.assertEqual(self.db._max_feature_size, new._max_feature_size) + self.assertEqual( + self.db.feature_extractor.__class__, new.feature_extractor.__class__ + ) + self.assertEqual(self.db.feature_extractor.n, new.feature_extractor.n) + self.assertEqual( + set(self.db.feature_set_size_to_string_map.iterkeys()), set(new.feature_set_size_to_string_map.iterkeys()) + ) + self.assertEqual( + set(self.db.feature_set_size_and_feature_to_string_map.iterkeys()), + set(new.feature_set_size_and_feature_to_string_map.iterkeys()), + ) + + os.remove("test.pkl")