Coverage for source/environment/labeled_data_balancer.py: 40%

10 statements  

« prev     ^ index     » next       coverage.py v7.8.0, created at 2025-07-27 17:11 +0000

1# environment/labeled_data_balancer.py 

2 

3# global imports 

4from imblearn.base import BaseSampler 

5 

6# local imports 

7 

class LabeledDataBalancer:
    """
    Implements a labeled data balancer that uses a list of samplers to balance the input and output data.

    The samplers are applied one after another, in the order given: the output of each
    sampler's ``fit_resample`` call becomes the input of the next one.
    """

    def __init__(self, balancers: list[BaseSampler]) -> None:
        """
        Class constructor. Initializes the balancer with a list of samplers.

        Parameters:
            balancers (list[BaseSampler]): A list of samplers to be used for balancing the data.
                Each one must provide a ``fit_resample(input_data, output_data)`` method.
        """

        # Keep a private copy so that later changes to the caller's list
        # cannot silently alter the configured sampler pipeline.
        self.__balancers = list(balancers)

    def balance(self, input_data: list[list[float]], output_data: list[int]) -> tuple[list[list[float]], list[int]]:
        """
        Balances the input and output data using the configured samplers.

        Parameters:
            input_data (list[list[float]]): The input data to be balanced.
            output_data (list[int]): The output data to be balanced.

        Raises:
            ValueError: If the input data and output data do not have the same length.

        Returns:
            (tuple[list[list[float]], list[int]]): A tuple containing the balanced input data and output data.
                When no samplers are configured, the inputs are returned unchanged.
        """

        # Validate once, up front, before any sampler sees the data.
        if len(input_data) != len(output_data):
            raise ValueError("Input data and output data must have the same length.")

        # Chain the samplers: each one resamples whatever the previous one produced.
        for balancer in self.__balancers:
            input_data, output_data = balancer.fit_resample(input_data, output_data)

        return input_data, output_data