Coverage for source/utils/policy_from_string_converter.py: 100%
12 statements
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-30 15:13 +0000
« prev ^ index » next coverage.py v7.8.0, created at 2025-05-30 15:13 +0000
1# utils/policy_from_string_converter.py
3from typing import Any, Type
4from rl.policy import BoltzmannQPolicy, EpsGreedyQPolicy, LinearAnnealedPolicy, Policy
6from .base_from_string_converter import BaseFromStringConverter
class PolicyFromStringConverter(BaseFromStringConverter):
    """
    Converts string identifiers to reinforcement learning policy objects.

    This class maps string names to keras-rl policy classes. Identifiers that
    start with ``linear_annealed`` are handled here: the mapped base policy is
    instantiated with no arguments and wrapped in a ``LinearAnnealedPolicy``.
    All other identifiers are delegated to ``BaseFromStringConverter``.
    """

    # Identifiers with this prefix denote a base policy that must be wrapped
    # in a LinearAnnealedPolicy instead of being returned directly.
    _LINEAR_ANNEALED_PREFIX = 'linear_annealed'

    def __init__(self, **kwargs: Any) -> None:
        """
        Initializes the policy converter with options for policy configuration.

        Parameters:
            **kwargs (Any): Optional keyword arguments stored on the instance.
                For ``linear_annealed_*`` identifiers they are forwarded to the
                ``LinearAnnealedPolicy`` constructor (the wrapped base policy
                itself is built without arguments). How they are used for the
                other identifiers depends on ``BaseFromStringConverter``.
        """
        self._kwargs: dict[str, Any] = kwargs
        # Both annealed variants reuse the plain policy class; the wrapping
        # happens in convert_from_string.
        self._value_map: dict[str, Type[Policy]] = {
            'boltzmann': BoltzmannQPolicy,
            'eps_greedy': EpsGreedyQPolicy,
            'linear_annealed_boltzmann': BoltzmannQPolicy,
            'linear_annealed_eps_greedy': EpsGreedyQPolicy,
        }

    def convert_from_string(self, string_value: str) -> Policy:
        """
        Converts a string identifier to its corresponding policy object.

        Overrides the base method to handle linear annealed policies, which
        require wrapping the base policy in a LinearAnnealedPolicy.

        Parameters:
            string_value (str): String identifier of the policy.

        Returns:
            Policy: For a known ``linear_annealed_*`` identifier, a
                LinearAnnealedPolicy wrapping a default-constructed base
                policy. Otherwise, whatever the base class conversion returns.
        """
        policy_class = self._value_map.get(string_value)
        is_annealed = string_value.startswith(self._LINEAR_ANNEALED_PREFIX)
        if policy_class is not None and is_annealed:
            # The stored kwargs are passed to the wrapper, not the inner
            # policy. NOTE(review): keras-rl's LinearAnnealedPolicy presumably
            # needs attr/value_max/value_min/value_test/nb_steps here — confirm
            # callers always supply them.
            return LinearAnnealedPolicy(policy_class(), **self._kwargs)
        return super().convert_from_string(string_value)