Coverage for test/unit/kernel/test_decision.py: 99%
372 statements
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-12 17:09 +0000
« prev ^ index » next coverage.py v7.6.12, created at 2025-03-12 17:09 +0000
1"""Test for decision.py"""
2import json
3import os
4import tempfile
5import unittest
6from unittest.mock import patch
8import pint
10from bim2sim.kernel import decision
11from bim2sim.kernel.decision import BoolDecision, save, load, RealDecision, \
12 DecisionBunch, ListDecision, GuidDecision
13from bim2sim.kernel.decision.console import ConsoleDecisionHandler
14from bim2sim.elements.mapping.units import ureg
class DecisionTestBase(unittest.TestCase):
    """Common parent class of all decision related test cases."""
class TestDecision(DecisionTestBase):
    """General tests for behaviour shared by the decision types."""

    def test_decision_value(self):
        """Check value access and validity before and after answering."""
        dec = decision.RealDecision(question="??")

        # Reading the value of an unanswered decision is expected to raise;
        # the result is discarded on purpose.
        with self.assertRaises(ValueError):
            _ = dec.value
        self.assertFalse(dec.valid())

        dec.value = 5.
        self.assertTrue(dec.valid())
        self.assertEqual(dec.value, 5)
        self.assertIsInstance(dec.value, pint.Quantity)

    def test_invalid_value(self):
        """Check that a non-numeric string is rejected by a RealDecision."""
        dec = RealDecision(question="??")
        self.assertFalse(dec.valid())

        with self.assertRaises(ValueError):
            dec.value = 'five'

    def test_skip(self):
        """Check skipping decisions with and without ``allow_skip``."""
        dec1 = BoolDecision(
            question="??",
            key="key1",
            allow_skip=True)
        self.assertFalse(dec1.valid())
        dec1.skip()
        # a skipped decision counts as valid and carries no value
        self.assertTrue(dec1.valid())
        self.assertIsNone(dec1.value)

        dec2 = BoolDecision(
            question="??",
            key="key2",
            allow_skip=False)

        with self.assertRaises(decision.DecisionException):
            dec2.skip()

    def test_freeze_decision(self):
        """Check freezing a decision and changing its value afterwards."""
        dec = BoolDecision('??')
        # freezing an unanswered decision is expected to fail
        with self.assertRaises(AssertionError):
            dec.freeze()
        dec.value = True
        dec.freeze()
        # a frozen decision must not accept a new value
        with self.assertRaises(AssertionError):
            dec.value = False

    def check(self, value):
        """Validation function: accept values strictly between 0 and 10."""
        return 0 < float(value) < 10

    def test_save_load(self):
        """Check saving a bunch of decisions and loading them again."""
        key_bool = "key_bool"
        key_real = "key_real"

        dec_bool = BoolDecision(
            question="??",
            global_key=key_bool)
        dec_bool.value = False
        self.assertFalse(dec_bool.value)

        dec_real = RealDecision(
            question="??",
            validate_func=self.check,
            global_key=key_real)
        dec_real.value = 5.
        self.assertEqual(dec_real.value, 5)

        decisions = DecisionBunch((dec_bool, dec_real))
        with tempfile.TemporaryDirectory(prefix='bim2sim_') as directory:
            path = os.path.join(directory, "mixed")
            save(decisions, path)

            # clear variables to simulate program restart
            dec_real.reset()
            dec_bool.reset()
            del decisions

            # load while the temporary directory still exists
            loaded_decisions = load(path)

        dec_real.reset_from_deserialized(loaded_decisions[key_real])
        dec_bool.reset_from_deserialized(loaded_decisions[key_bool])
        self.assertEqual(dec_real.value, 5)
        self.assertFalse(dec_bool.value)

    def test_decision_reduce_by_key(self):
        """Check ``get_reduced_bunch`` for decisions sharing a key."""
        dec_1 = BoolDecision(key='key1', question="??")
        dec_2 = BoolDecision(key='key1', question="??")
        dec_3 = BoolDecision(key='key2', question="??")
        dec_bunch_orig = DecisionBunch([dec_1, dec_2, dec_3])
        # dec_2 repeats the key of dec_1 and is expected among the doubles
        dec_bunch_exp = DecisionBunch([dec_1, dec_3])
        doubled_bunch_exp = DecisionBunch([dec_2])
        red_bunch, doubled_dec = dec_bunch_orig.get_reduced_bunch(
            criteria='key')
        self.assertListEqual(dec_bunch_exp, red_bunch)
        self.assertListEqual(doubled_dec, doubled_bunch_exp)

    def test_decision_reduce_by_question(self):
        """Check ``get_reduced_bunch`` for decisions sharing a question."""
        dec_1 = BoolDecision(key='key1', question="question A ?")
        dec_2 = BoolDecision(key='key2', question="question A ?")
        dec_3 = BoolDecision(key='key3', question="question B ?")
        dec_bunch_orig = DecisionBunch([dec_1, dec_2, dec_3])
        # dec_2 repeats the question of dec_1 and is expected among the
        # doubles
        dec_bunch_exp = DecisionBunch([dec_1, dec_3])
        doubled_bunch_exp = DecisionBunch([dec_2])
        red_bunch, doubled_dec = dec_bunch_orig.get_reduced_bunch(
            criteria='question')
        self.assertListEqual(dec_bunch_exp, red_bunch)
        self.assertListEqual(doubled_dec, doubled_bunch_exp)
class TestBoolDecision(DecisionTestBase):
    """Tests specific to BoolDecision."""

    def test_decision_value(self):
        """Check that assigned boolean values are returned unchanged."""
        answered_yes = BoolDecision(question="??")
        answered_yes.value = True
        self.assertTrue(answered_yes.value)

        answered_no = BoolDecision(question="??")
        answered_no.value = False
        self.assertFalse(answered_no.value)

    def test_validation(self):
        """Check that only real booleans pass validation."""
        dec = decision.BoolDecision(question="??")

        for accepted in (True, False):
            self.assertTrue(dec.validate(accepted))
        # None, integers and strings are not accepted as booleans
        for rejected in (None, 0, 1, 'y'):
            self.assertFalse(dec.validate(rejected))

    def test_save_load(self):
        """Check a JSON round trip of a single BoolDecision."""
        dec = decision.BoolDecision(question="??", global_key="key1")
        dec.value = True
        self.assertTrue(dec.value)

        # serialize to JSON text and read it back
        as_json = json.dumps(dec.get_serializable())
        restored = json.loads(as_json)

        # after a reset the decision is unanswered again
        dec.reset()
        self.assertFalse(dec.valid())

        dec.reset_from_deserialized(restored)
        self.assertTrue(dec.value)
@patch('builtins.print', lambda *args, **kwargs: None)
class TestRealDecision(DecisionTestBase):
    """Tests specific to RealDecision (``print`` is replaced by a no-op)."""

    def check(self, value):
        """Validation function: accept lengths strictly between 0 and 10 m."""
        metres = float(value.m_as('m'))
        return 0 < metres < 10

    def test_validation(self):
        """Check validation with and without a custom validate function."""
        unit = ureg.meter

        plain = decision.RealDecision(question="??", unit=unit)
        for accepted in (5. * unit, 15. * unit, 5., 5):
            self.assertTrue(plain.validate(accepted))
        self.assertFalse(plain.validate(False))

        # with self.check attached, 15 m is out of range
        restricted = decision.RealDecision(
            question="??", unit=unit, validate_func=self.check)
        self.assertTrue(restricted.validate(5. * unit))
        self.assertFalse(restricted.validate(15. * unit))
        self.assertTrue(restricted.validate(5 * unit))
        self.assertFalse(restricted.validate(False))

    def test_save_load(self):
        """Check a JSON round trip of RealDecisions with and without unit."""
        unitless = RealDecision(question="??", global_key="key1")
        unitless.value = 5.
        with_unit = RealDecision(
            question="??",
            unit=ureg.meter,
            validate_func=self.check,
            global_key="key2")
        with_unit.value = 5.

        self.assertTrue(unitless.value)
        self.assertTrue(with_unit.value)

        # serialize to JSON text and read it back
        restored_unitless = json.loads(
            json.dumps(unitless.get_serializable()))
        restored_with_unit = json.loads(
            json.dumps(with_unit.get_serializable()))

        # after a reset both decisions are unanswered again
        unitless.reset()
        with_unit.reset()
        self.assertFalse(unitless.valid())
        self.assertFalse(with_unit.valid())

        unitless.reset_from_deserialized(restored_unitless)
        with_unit.reset_from_deserialized(restored_with_unit)

        self.assertTrue(unitless.value)
        self.assertTrue(with_unit.value)

        self.assertEqual(unitless.value, 5.)
        self.assertIsInstance(unitless.value, pint.Quantity)

        self.assertEqual(with_unit.value.m_as('m'), 5)
        self.assertIsInstance(with_unit.value.m, float)
254# IntDecision not implemented
class TestListDecision(DecisionTestBase):
    """Tests specific to ListDecision."""

    def setUp(self) -> None:
        """Provide three (value, label) choices for every test."""
        super().setUp()
        labels = ('option1', 'option2', 'option3')
        self.choices = list(zip('abc', labels))

    def test_validation(self):
        """Check that only the values of the given choices are valid."""
        dec = decision.ListDecision("??", choices=self.choices)

        for accepted in ('a', 'c'):
            self.assertTrue(dec.validate(accepted))
        for rejected in ('1', 1, 3):
            self.assertFalse(dec.validate(rejected))

    def test_save_load(self):
        """Check a JSON round trip of a single ListDecision."""
        dec = decision.ListDecision(
            question="??",
            choices=self.choices,
            global_key="key1")
        dec.value = 'b'
        self.assertTrue(dec.value)

        # serialize to JSON text and read it back
        as_json = json.dumps(dec.get_serializable())
        restored = json.loads(as_json)

        # after a reset the decision is unanswered again
        dec.reset()
        self.assertFalse(dec.valid())

        dec.reset_from_deserialized(restored)
        self.assertEqual('b', dec.value)
        self.assertIsInstance(dec.value, str)
class TestStringDecision(DecisionTestBase):
    """Tests specific to StringDecision."""

    def check(self, value):
        """Validation function: accept only the string 'success'."""
        return value == 'success'

    def test_validation(self):
        """Check validation with and without a custom validate function."""
        plain = decision.StringDecision(question="??")

        for accepted in ('1', 'test'):
            self.assertTrue(plain.validate(accepted))
        # non-strings and the empty string are rejected
        for rejected in (1, None, ''):
            self.assertFalse(plain.validate(rejected))

        restricted = decision.StringDecision(
            question="??", validate_func=self.check)
        self.assertTrue(restricted.validate('success'))
        self.assertFalse(restricted.validate('other'))

    def test_save_load(self):
        """Check a JSON round trip of two StringDecisions."""
        plain = decision.StringDecision(
            question="??",
            global_key="key1")
        plain.value = 'success'
        restricted = decision.StringDecision(
            question="??",
            validate_func=self.check,
            global_key="key2")
        restricted.value = 'success'

        self.assertEqual('success', plain.value)
        self.assertEqual('success', restricted.value)

        # serialize to JSON text and read it back
        restored_plain = json.loads(json.dumps(plain.get_serializable()))
        restored_restricted = json.loads(
            json.dumps(restricted.get_serializable()))

        # after a reset both decisions are unanswered again
        plain.reset()
        restricted.reset()
        self.assertFalse(plain.valid())
        self.assertFalse(restricted.valid())

        plain.reset_from_deserialized(restored_plain)
        restricted.reset_from_deserialized(restored_restricted)

        self.assertEqual('success', plain.value)
        self.assertEqual('success', restricted.value)
class TestGuidDecision(DecisionTestBase):
    """Tests specific to GuidDecision, whose value is a set of GUIDs."""

    def check(self, value):
        """Validation function: accept only GUIDs from a fixed list."""
        valids = (
            '2tHa09veL10P9$ol9urWrT',
            '0otlA1TWvCPvzfXTM_RO1R',
            '2GCvzU9J93CxAS3rHyr1a6'
        )
        return all(guid in valids for guid in value)

    def test_validation(self):
        """Check validation of single and multiple GUID answers."""
        dec = decision.GuidDecision(question="??", multi=False)

        self.assertTrue(dec.validate({'2tHa09veL10P9$ol9urWrT'}))
        # Two separate GUIDs must be rejected when multi is disabled.
        # NOTE(review): the original assertion passed a set with ONE
        # comma-joined string here, so the multi restriction was never
        # exercised; that case is kept below as a malformed GUID.
        self.assertFalse(
            dec.validate(
                {'2tHa09veL10P9$ol9urWrT', '0otlA1TWvCPvzfXTM_RO1R'}),
            'multi not allowed')
        self.assertFalse(
            dec.validate(
                {'2tHa09veL10P9$ol9urWrT, 0otlA1TWvCPvzfXTM_RO1R'}),
            'comma joined string is not a GUID')
        self.assertFalse(dec.validate({'2tHa09veL10P9$ol'}))
        self.assertFalse(dec.validate(''))
        self.assertFalse(dec.validate(1))
        self.assertFalse(dec.validate(None))

        dec_multi = decision.GuidDecision(
            question="??", validate_func=self.check, multi=True)

        self.assertTrue(dec_multi.validate({'2tHa09veL10P9$ol9urWrT'}))
        self.assertTrue(dec_multi.validate(
            {'2tHa09veL10P9$ol9urWrT', '0otlA1TWvCPvzfXTM_RO1R'}))
        # one GUID unknown to self.check invalidates the whole answer
        self.assertFalse(dec_multi.validate(
            {'2tHa09veL10P9$ol9urWrT', 'GUID_not_in_valid_list'}))

    def test_save_load(self):
        """Check a JSON round trip of single and multi GuidDecisions."""
        key1 = "key1"
        key2 = "key2"
        guid1 = {'2tHa09veL10P9$ol9urWrT'}
        guid2 = {'2tHa09veL10P9$ol9urWrT', '2GCvzU9J93CxAS3rHyr1a6'}

        dec1 = decision.GuidDecision(
            question="??",
            global_key=key1)
        dec1.value = guid1
        dec2 = decision.GuidDecision(
            question="??",
            multi=True,
            validate_func=self.check,
            global_key=key2)
        dec2.value = guid2

        self.assertSetEqual(guid1, dec1.value)
        self.assertSetEqual(guid2, dec2.value)

        # check serialize
        serialized1 = json.dumps(dec1.get_serializable())
        serialized2 = json.dumps(dec2.get_serializable())
        deserialized1 = json.loads(serialized1)
        deserialized2 = json.loads(serialized2)

        # check reset
        dec1.reset()
        dec2.reset()
        self.assertFalse(dec1.valid())
        self.assertFalse(dec2.valid())
        dec1.reset_from_deserialized(deserialized1)
        dec2.reset_from_deserialized(deserialized2)

        # the restored values are sets again
        self.assertSetEqual(guid1, dec1.value)
        self.assertSetEqual(guid2, dec2.value)
@patch('builtins.print', lambda *args, **kwargs: None)
class TestConsoleHandler(DecisionTestBase):
    """Tests for answering decisions through the ConsoleDecisionHandler.

    ``print`` is replaced by a no-op for every test; ``input`` is patched
    per test to simulate what the user types.
    """
    # one handler instance shared by all tests of this class
    handler = ConsoleDecisionHandler()

    def check(self, value):
        """Validation function that accepts every value."""
        return True

    def test_default_value(self):
        """Check that the default value is used on empty input."""
        real_dec = RealDecision("??", unit=ureg.meter, default=10)
        bool_dec = BoolDecision("??", default=False)
        list_dec = ListDecision("??", choices="ABC", default="C")
        decisions = DecisionBunch((real_dec, bool_dec, list_dec))

        with patch('builtins.input', lambda *args, **kwargs: ''):
            answers = self.handler.get_answers_for_bunch(decisions)

        expected = [10 * ureg.m, False, 'C']
        self.assertListEqual(expected, answers)

    @patch('builtins.input', lambda *args, **kwargs: 'bla bla')
    def test_bad_input(self):
        """Check that input which never parses ends in DecisionCancel."""
        dec = BoolDecision(question="??")
        bunch = DecisionBunch([dec])
        with self.assertRaises(decision.DecisionCancel):
            self.handler.get_answers_for_bunch(bunch)

    def test_skip_all(self):
        """Check skipping all collected decisions at once."""
        decisions = DecisionBunch()
        for i in range(3):
            decisions.append(BoolDecision(
                question="??",
                key=f"n{i}",
                allow_skip=True))

        with patch('builtins.input', lambda *args, **kwargs: 'skip all'):
            answers = self.handler.get_answers_for_bunch(decisions)

        # one answer per decision, none of them truthy
        self.assertEqual(len(answers), 3)
        self.assertFalse(any(answers))

    def test_real_parsing(self):
        """Check interpretation of numeric console input."""
        expected_valids = {
            '1': 1,
            '1.': 1,
            '1.0': 1,
            '1.1': 1.1,
            '1e0': 1,
            '1e-1': 0.1,
        }
        unit = ureg.meter

        for inp, res in expected_valids.items():
            # bind inp as default so the lambda does not depend on the
            # loop variable at call time
            with patch('builtins.input', lambda *args, inp=inp: inp):
                dec = RealDecision(
                    question="??", unit=unit, validate_func=self.check)
                answer = self.handler.user_input(dec)
                self.assertEqual(res * unit, answer)

    def test_bool_parse(self):
        """Check parsing of raw input for a BoolDecision."""
        dec = decision.BoolDecision(question="??")
        self.assertFalse(dec.valid())

        parsed_int1 = self.handler.parse(dec, '1')
        self.assertTrue(parsed_int1)
        self.assertIsInstance(parsed_int1, bool)
        parsed_int0 = self.handler.parse(dec, '0')
        self.assertFalse(parsed_int0)
        self.assertIsInstance(parsed_int0, bool)
        # input that cannot be read as a bool parses to None
        parsed_real = self.handler.parse(dec, '1.1')
        self.assertIsNone(parsed_real)
        parsed_str = self.handler.parse(dec, 'y')
        self.assertTrue(parsed_str)
        self.assertIsInstance(parsed_str, bool)
        parsed_invalid = self.handler.parse(dec, 'foo')
        self.assertIsNone(parsed_invalid)

    def test_real_parse(self):
        """Check parsing of raw input for a RealDecision."""
        dec = RealDecision(question="??")
        self.assertFalse(dec.valid())

        parsed_int = self.handler.parse(dec, 5)
        self.assertEqual(parsed_int, 5.)
        self.assertIsInstance(parsed_int, pint.Quantity)
        parsed_real = self.handler.parse(dec, 5.)
        self.assertEqual(parsed_real, 5.)
        self.assertIsInstance(parsed_real, pint.Quantity)
        parsed_str = self.handler.parse(dec, '5')
        self.assertEqual(parsed_str, 5.)
        self.assertIsInstance(parsed_str, pint.Quantity)
        # non-numeric input parses to None
        parsed_invalid = self.handler.parse(dec, 'five')
        self.assertIsNone(parsed_invalid)

    def test_list_parse(self):
        """Check parsing of raw input for a ListDecision."""
        choices = [
            ('a', 'option1'),
            ('b', 'option2'),
            ('c', 'option3')
        ]
        dec = ListDecision("??", choices=choices)
        self.assertFalse(dec.valid())

        # an index into the choices is parsed to the value of that choice
        parsed_first = self.handler.parse(dec, 0)
        self.assertEqual(parsed_first, 'a')

        parsed_last = self.handler.parse(dec, 2)
        self.assertEqual(parsed_last, 'c')

        # an index past the last choice parses to None
        parsed_out_of_range = self.handler.parse(dec, 3)
        self.assertIsNone(parsed_out_of_range)

        # the value of a choice itself is not accepted as input
        parsed_value = self.handler.parse(dec, 'a')
        self.assertIsNone(parsed_value)

    def test_string_parse(self):
        """Check console input for a StringDecision."""
        answers = ('success', 'other')
        for inp in answers:
            # bind inp as default so the lambda does not depend on the
            # loop variable at call time
            with patch('builtins.input', lambda *args, inp=inp: inp):
                dec = decision.StringDecision(question="??")
                answer = self.handler.user_input(dec)
                self.assertEqual(inp, answer)

        # empty input is expected to end in DecisionCancel
        with patch('builtins.input', lambda *args: ''):
            with self.assertRaises(decision.DecisionCancel):
                dec = decision.StringDecision(question="??")
                self.handler.user_input(dec)

    def test_guid_parse(self):
        """Check parsing and validation of console input for GUIDs."""

        def check(value):
            """Validation function: accept only GUIDs from a fixed list."""
            valids = (
                '2tHa09veL10P9$ol9urWrT',
                '0otlA1TWvCPvzfXTM_RO1R',
                '2GCvzU9J93CxAS3rHyr1a6'
            )
            return all(guid in valids for guid in value)

        def valid_parsed(guid_decision, inp):
            """Parse ``inp`` with the handler and validate the result."""
            return guid_decision.validate(
                self.handler.parse_guid_input(inp))

        # test parse + validate
        dec = decision.GuidDecision(question="??", multi=False)
        self.assertTrue(valid_parsed(dec, '2tHa09veL10P9$ol9urWrT'))
        # multi not allowed
        self.assertFalse(valid_parsed(
            dec, '2tHa09veL10P9$ol9urWrT, 0otlA1TWvCPvzfXTM_RO1R'))
        self.assertFalse(valid_parsed(dec, '2tHa09veL10P9$ol'))
        self.assertFalse(valid_parsed(dec, ''))
        self.assertFalse(valid_parsed(dec, 1))
        self.assertFalse(valid_parsed(dec, None))

        # test parse + validate in multi guid decision
        dec_multi = decision.GuidDecision(
            question="??", validate_func=check, multi=True)
        self.assertTrue(valid_parsed(dec_multi, '2tHa09veL10P9$ol9urWrT'))
        # comma and/or whitespace separate several GUIDs
        self.assertTrue(valid_parsed(
            dec_multi, '2tHa09veL10P9$ol9urWrT, 0otlA1TWvCPvzfXTM_RO1R'))
        self.assertTrue(valid_parsed(
            dec_multi, '2tHa09veL10P9$ol9urWrT 0otlA1TWvCPvzfXTM_RO1R'))
        self.assertTrue(valid_parsed(
            dec_multi, '2tHa09veL10P9$ol9urWrT,0otlA1TWvCPvzfXTM_RO1R'))
        self.assertFalse(valid_parsed(
            dec_multi, '2tHa09veL10P9$ol9urWrT, GUID_not_in_valid_list'))
        # a semicolon separated answer does not validate
        self.assertFalse(valid_parsed(
            dec_multi, '2tHa09veL10P9$ol9urWrT; 0otlA1TWvCPvzfXTM_RO1R'))

        # full test
        dec1 = GuidDecision(question="??")
        with patch('builtins.input',
                   lambda *args: '2tHa09veL10P9$ol9urWrT'):
            answer = self.handler.user_input(dec1)
        self.assertSetEqual({'2tHa09veL10P9$ol9urWrT'}, answer)

        dec2 = GuidDecision(question="??", validate_func=check, multi=True)
        with patch(
                'builtins.input',
                lambda *args:
                '2tHa09veL10P9$ol9urWrT, 0otlA1TWvCPvzfXTM_RO1R'):
            answer2 = self.handler.user_input(dec2)
        self.assertSetEqual(
            {'2tHa09veL10P9$ol9urWrT', '0otlA1TWvCPvzfXTM_RO1R'}, answer2)
# Run the whole test module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()