#!/usr/bin/env python
# -*- coding:utf-8 -*-
from predict.base_handler import BasePredictor
from static_config import EN_TO_ZH_MAP

class KwPredict(BasePredictor):
    """Predictor that dispatches to the combined model or a per-type model.

    The combined model lives on ``BasePredictor.model`` /
    ``BasePredictor.extractor``; per-type models are looked up by name as
    ``BasePredictor.<predict_type>_model`` /
    ``BasePredictor.<predict_type>_extractor``.
    """

    def predict(self, string_or_list, predict_type='all'):
        """Run a prediction and return it wrapped in a result dict.

        :param string_or_list: input to classify. For ``predict_type='all'``
            it is passed unchanged to ``get_single_predict``; for any other
            type it may be a single item or a ``list`` of items.
        :param predict_type: ``'all'`` to use the combined model, otherwise
            the prefix of the ``<predict_type>_model`` and
            ``<predict_type>_extractor`` attributes on ``BasePredictor``.
        :returns:
            * ``'all'``: ``{'result': <Chinese label>,
              'probably_dict': {<Chinese label>: prob, ...},
              'predict_type': 'all'}``
            * other type, list input: ``{'result': {item: int_label, ...},
              'predict_type': predict_type}``
            * other type, single input: ``{'result': int_label,
              'predict_type': predict_type}``
        :raises AttributeError: if ``BasePredictor`` has no model or
            extractor attribute for ``predict_type``.
        """
        self.info.info('预测类型为: %s, 接收数据: %s' % (predict_type, string_or_list))

        if predict_type == 'all':
            result, prob = self.get_single_predict(
                string_or_list, BasePredictor.model, BasePredictor.extractor)
            # Translate English labels to Chinese for both the winning label
            # and the full probability table (presumably EN_TO_ZH_MAP is a
            # dict, so an unmapped label raises KeyError).
            zh_prob = {EN_TO_ZH_MAP[k]: v for k, v in prob.items()}
            return {'result': EN_TO_ZH_MAP[result], 'probably_dict': zh_prob, 'predict_type': predict_type}

        model = getattr(BasePredictor, predict_type + '_model')
        extractor = getattr(BasePredictor, predict_type + '_extractor')

        if isinstance(string_or_list, list):
            # Later duplicates overwrite earlier ones, as a plain dict would.
            dic_result = {
                item: self._predict_int_label(item, model, extractor)
                for item in string_or_list
            }
            return {'result': dic_result, 'predict_type': predict_type}

        return {
            'result': self._predict_int_label(string_or_list, model, extractor),
            'predict_type': predict_type,
        }

    def _predict_int_label(self, text, model, extractor):
        """Predict a single item with a per-type model and return an int label.

        Shared by the single-item and list paths of ``predict`` so both
        convert the model output in exactly the same way.
        """
        result, _prob = self.get_single_predict(text, model, extractor)
        # NOTE(review): indexing result[0] assumes the per-type label is
        # indexable (e.g. a one-element array or a one-character string) --
        # confirm against get_single_predict.
        return int(result[0])