Coverage for source/utils/policy_from_string_converter.py: 100%

12 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-05-30 15:13 +0000

1# utils/policy_from_string_converter.py 

2 

3from typing import Any, Type 

4from rl.policy import BoltzmannQPolicy, EpsGreedyQPolicy, LinearAnnealedPolicy, Policy 

5 

6from .base_from_string_converter import BaseFromStringConverter 

7 

class PolicyFromStringConverter(BaseFromStringConverter):
    """
    Converts string identifiers to reinforcement learning policy objects.

    String names are mapped to policy classes. Identifiers that contain
    'linear_annealed' get special handling: the mapped base policy is
    instantiated and wrapped in a LinearAnnealedPolicy. Every other
    identifier is delegated to the base converter.
    """

    def __init__(self, **kwargs: Any) -> None:
        """
        Initializes the policy converter with options for policy configuration.

        Parameters:
            **kwargs (Any): Optional keyword parameters stored on the converter.
                For linear annealed identifiers they are forwarded to the
                LinearAnnealedPolicy constructor. How the base converter uses
                them for other identifiers is not visible in this class.
        """

        # The annotation on **kwargs describes each value, hence Any above;
        # the collected mapping itself is a dict[str, Any].
        self._kwargs: dict[str, Any] = kwargs
        # The 'linear_annealed_*' keys map to the *inner* policy class; the
        # wrapping itself happens in convert_from_string.
        self._value_map: dict[str, Type[Policy]] = {
            'boltzmann': BoltzmannQPolicy,
            'eps_greedy': EpsGreedyQPolicy,
            'linear_annealed_boltzmann': BoltzmannQPolicy,
            'linear_annealed_eps_greedy': EpsGreedyQPolicy,
        }

    def convert_from_string(self, string_value: str) -> Policy:
        """
        Converts a string identifier to its corresponding policy object.

        Overrides the base method to provide special handling for linear
        annealed policies, which require wrapping the base policy in a
        LinearAnnealedPolicy.

        Parameters:
            string_value (str): String identifier of the policy.

        Returns:
            Policy: For a known identifier containing 'linear_annealed', a
                LinearAnnealedPolicy wrapping a default-constructed instance
                of the mapped policy class, configured with the stored keyword
                arguments. Otherwise, whatever the base converter returns for
                the identifier.
        """

        if 'linear_annealed' in string_value and string_value in self._value_map:
            # Membership was checked above, so index directly instead of
            # using .get(), which would yield an Optional that is then called.
            inner_policy_class = self._value_map[string_value]
            return LinearAnnealedPolicy(inner_policy_class(), **self._kwargs)

        # Unknown or non-annealed identifiers follow the generic conversion path.
        return super().convert_from_string(string_value)