Module monk.pytorch.datasets.class_imbalance

Source code:
from pytorch.datasets.imports import *
from system.imports import *




def balance_class_weights(label_list, nclasses, *, show_progress=True):
    """Compute inverse-frequency class weights and per-sample weights.

    Each class ``c`` gets weight ``N / count[c]``, where ``N`` is the total
    number of labels. Every sample then receives the weight of its class.

    Args:
        label_list: Sized sequence of integer class labels in
            ``[0, nclasses)``.
        nclasses: Number of classes.
        show_progress: When True (default), display a tqdm progress bar for
            each pass over ``label_list``.

    Returns:
        A tuple ``(weight, weight_per_class, count)``:
        ``weight`` is a list with one weight per entry of ``label_list``;
        ``weight_per_class`` is a list of ``nclasses`` floats, where a class
        with no samples gets ``0.0`` instead of dividing by zero;
        ``count`` is the number of samples per class.

    Raises:
        IndexError: If a label lies outside ``[0, nclasses)``. Negative labels
            are rejected rather than silently counted into the last classes.
    """

    def _iterate(items):
        # Yield items unchanged, advancing a progress bar that is always
        # closed (the `with` block) once iteration finishes or is abandoned.
        if not show_progress:
            yield from items
            return
        with tqdm(total=len(items)) as pbar:
            for item in items:
                pbar.update()
                yield item

    count = [0] * nclasses
    for val in _iterate(label_list):
        if not 0 <= val < nclasses:
            raise IndexError(
                "label {!r} out of range for {} classes".format(val, nclasses)
            )
        count[val] += 1

    N = float(sum(count))
    # Absent classes contribute no samples, so their weight is never used
    # per sample; 0.0 avoids a ZeroDivisionError.
    weight_per_class = [N / float(c) if c else 0.0 for c in count]

    weight = [weight_per_class[val] for val in _iterate(label_list)]
    return weight, weight_per_class, count

Functions

def balance_class_weights(label_list, nclasses)
Source code:
def balance_class_weights(label_list, nclasses, *, show_progress=True):
    """Compute inverse-frequency class weights and per-sample weights.

    Each class ``c`` gets weight ``N / count[c]``, where ``N`` is the total
    number of labels. Every sample then receives the weight of its class.

    Args:
        label_list: Sized sequence of integer class labels in
            ``[0, nclasses)``.
        nclasses: Number of classes.
        show_progress: When True (default), display a tqdm progress bar for
            each pass over ``label_list``.

    Returns:
        A tuple ``(weight, weight_per_class, count)``:
        ``weight`` is a list with one weight per entry of ``label_list``;
        ``weight_per_class`` is a list of ``nclasses`` floats, where a class
        with no samples gets ``0.0`` instead of dividing by zero;
        ``count`` is the number of samples per class.

    Raises:
        IndexError: If a label lies outside ``[0, nclasses)``. Negative labels
            are rejected rather than silently counted into the last classes.
    """

    def _iterate(items):
        # Yield items unchanged, advancing a progress bar that is always
        # closed (the `with` block) once iteration finishes or is abandoned.
        if not show_progress:
            yield from items
            return
        with tqdm(total=len(items)) as pbar:
            for item in items:
                pbar.update()
                yield item

    count = [0] * nclasses
    for val in _iterate(label_list):
        if not 0 <= val < nclasses:
            raise IndexError(
                "label {!r} out of range for {} classes".format(val, nclasses)
            )
        count[val] += 1

    N = float(sum(count))
    # Absent classes contribute no samples, so their weight is never used
    # per sample; 0.0 avoids a ZeroDivisionError.
    weight_per_class = [N / float(c) if c else 0.0 for c in count]

    weight = [weight_per_class[val] for val in _iterate(label_list)]
    return weight, weight_per_class, count