Commit 92d955a7 by lichenggang

1. 静态配置添加带 * 前缀的表头别名(*参数、*型号、*数量、*品牌)

2. 模型判断里面对需要预测的列做去重
parent f055b375
...@@ -126,6 +126,38 @@ class BasePredictor: ...@@ -126,6 +126,38 @@ class BasePredictor:
rate = round(len(set(data)) / len(data), 3) rate = round(len(set(data)) / len(data), 3)
return rate return rate
@classmethod
def repeat_max(cls, li):
    """Return the element that occurs most often in ``li``.

    When several elements share the highest count, the one that appears
    first in ``li`` wins. The vote merging in ``get_comprehensive_res``
    relies on this ordering.

    Args:
        li: Iterable of hashable items, e.g. a list of column names.

    Returns:
        The item with the highest occurrence count.

    Raises:
        IndexError: If ``li`` is empty.
    """
    # most_common(1) does a single scan instead of sorting every
    # (item, count) pair; on equal counts it keeps the first item seen,
    # exactly like the stable descending sort it replaces.
    # e.g. counts [('brand_name', 4), ('goods_name', 3), ('param', 2)] -> 'brand_name'
    return Counter(li).most_common(1)[0][0]
@classmethod
def get_comprehensive_res(cls, pre_std_result, model_std_result):
    """Combine header-based and model-based predictions into one mapping.

    Every ``column -> standard name`` pair from both inputs counts as one
    vote for that standard name. For each standard name with at least one
    vote, the column with the most votes is kept.

    Args:
        pre_std_result: Mapping of column -> standard name produced by the
            header prediction.
        model_std_result: Mapping of column -> standard name produced by the
            model prediction.

    Returns:
        dict mapping the winning column to its standard name. Empty when
        neither input contains any vote.

    Raises:
        KeyError: If a predicted standard name is not one of the five
            names listed in ``vote_count``.
    """
    print('表头预测结果', pre_std_result)
    print('模型预测结果', model_std_result)
    vote_count = {
        "类别": [],
        "参数": [],
        "型号": [],
        "数量": [],
        "品牌": []
    }
    # Header votes are appended before model votes. Conflicts between the
    # two predictors are not resolved explicitly: repeat_max keeps the
    # first-seen item on a tie, so the header prediction wins.
    for prediction in (pre_std_result, model_std_result):
        for col, std_name in prediction.items():
            vote_count[std_name].append(col)
    comprehensive_res = {}
    for std_name, col_li in vote_count.items():
        if col_li:
            # NOTE(review): if the same column wins for two standard names,
            # the later name in vote_count order overwrites the earlier one.
            comprehensive_res[cls.repeat_max(col_li)] = std_name
    return comprehensive_res
if __name__ == "__main__": if __name__ == "__main__":
li = ['型号', '电阻', '电阻', '电阻', '电阻', '电阻', '电阻', '电阻', '电阻', '电阻', '电阻', '电阻', '电容', '电容', '电容', '电容', '电容', '电容', '电容', '电容', '电容', '电感', '电感', '电感', '电感', '电感', '二极管', '二极管', '二极管', '二极管', '二极管', '二极管', '二极管', '场效应管', '场效应管', '场效应管', '场效应管', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '集成电路', '振荡器', '振荡器', '光电器件', '光电器件', '光电器件', '磁珠', '保险丝', '保险丝', '保险丝', '开关元件', '继电器', '继电器', '继电器', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '接插件', '模块', '模块', '模块'] pass
print(set(li)) \ No newline at end of file
#!/usr/bin/env python #!/usr/bin/env python
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
from collections import Counter
from predict.base_handler import BasePredictor from predict.base_handler import BasePredictor
from static_config import * from static_config import *
...@@ -13,13 +11,6 @@ def fun(seri): ...@@ -13,13 +11,6 @@ def fun(seri):
return field, seri.name return field, seri.name
def repeat_max(li):
    """Pick the element of ``li`` with the highest occurrence count.

    Elements with equal counts keep their first-seen order, so the
    earliest of them is returned. Raises ``IndexError`` for empty input.
    """
    tally = Counter(li)
    # Ranked pairs look like [('brand_name', 4), ('goods_name', 3), ('param', 2)]
    ranked = list(tally.items())
    ranked.sort(key=lambda pair: pair[1], reverse=True)
    winner, _count = ranked[0]
    return winner
# 取前多少行 # 取前多少行
HEAD_ROW = 7 HEAD_ROW = 7
# 空置率阈值 # 空置率阈值
...@@ -82,12 +73,12 @@ class DicPredict(BasePredictor): ...@@ -82,12 +73,12 @@ class DicPredict(BasePredictor):
for col in list_num_col_and_param_col: for col in list_num_col_and_param_col:
std_result.pop(col) std_result.pop(col)
pre_id_res = { head_id_res = {
'std_result': std_result, 'std_result': std_result,
'ab_result': ab_result, 'ab_result': ab_result,
} }
return pre_id_res return head_id_res
def model_predict(self, dic_data): def model_predict(self, dic_data):
""" """
...@@ -114,24 +105,28 @@ class DicPredict(BasePredictor): ...@@ -114,24 +105,28 @@ class DicPredict(BasePredictor):
if self.is_catecol(no_null_v): if self.is_catecol(no_null_v):
temp_pre_model_res[k] = '类别' temp_pre_model_res[k] = '类别'
not_null_dic_data = {k: list(filter(lambda x: x != self.PLACEHOLDER, dic_data[k])) for k in prob_columns} # 对列元素进行去重并处理掉占位符
for k, v in not_null_dic_data.items(): set_not_null_dic_data = {k: set(list(filter(lambda x: x != self.PLACEHOLDER, dic_data[k]))) for k in
prob_columns}
for k, v in set_not_null_dic_data.items():
li_single_pred_res = [] li_single_pred_res = []
for string in v: for string in v:
single_pred_res, probdic = self.get_single_predict(string) single_pred_res, probdic = self.get_single_predict(string)
li_single_pred_res.append(single_pred_res) li_single_pred_res.append(single_pred_res)
result = repeat_max(li_single_pred_res) result = BasePredictor.repeat_max(li_single_pred_res)
# 如果该列被预测为其他, 则不做改动 # 如果该列被预测为其他, 则不做改动
if result == 'other': if result == 'other':
continue continue
temp_pre_model_res[k] = EN_TO_ZH_MAP[result] temp_pre_model_res[k] = EN_TO_ZH_MAP[result]
# 若有多个参数列或型号列,则进行不同率的比较, 不同率最高的选为目标列 # 若有多个参数列或型号列,则进行不同率的比较, 不同率最高的选为目标列
not_null_dic_data = {k: list(filter(lambda x: x != self.PLACEHOLDER, dic_data[k])) for k in prob_columns}
prob_param_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '参数'] prob_param_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '参数']
prob_gn_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '型号'] prob_gn_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '型号']
for param_or_gn_col_list in [prob_param_cols, prob_gn_cols]: for param_or_gn_col_list in [prob_param_cols, prob_gn_cols]:
if len(param_or_gn_col_list) >= 2: if len(param_or_gn_col_list) >= 2:
li_diffrate = [(col, BasePredictor.get_diffrate(not_null_dic_data[col])) for col in param_or_gn_col_list] li_diffrate = [(col, BasePredictor.get_diffrate(not_null_dic_data[col])) for col in
param_or_gn_col_list]
sort_li_diffrate = sorted(li_diffrate, key=lambda x: x[1], reverse=True) sort_li_diffrate = sorted(li_diffrate, key=lambda x: x[1], reverse=True)
for col_diffrate in sort_li_diffrate[1:]: for col_diffrate in sort_li_diffrate[1:]:
temp_pre_model_res.pop(col_diffrate[0]) temp_pre_model_res.pop(col_diffrate[0])
...@@ -146,18 +141,18 @@ class DicPredict(BasePredictor): ...@@ -146,18 +141,18 @@ class DicPredict(BasePredictor):
dic_data = self.pre_deal(dic_data) dic_data = self.pre_deal(dic_data)
if predict_type == 'all': if predict_type == 'all':
pre_id_res = self.head_predict(dic_data) head_id_res = self.head_predict(dic_data)
model_id_res = self.model_predict(dic_data) model_id_res = self.model_predict(dic_data)
# 表头预测和模型预测最后返回的数据进行综合处理 # 表头预测和模型预测最后返回的数据进行综合处理
pre_std_result = pre_id_res.get('std_result') head_std_result = head_id_res.get('std_result')
model_std_result = model_id_res.get('std_result') model_std_result = model_id_res.get('std_result')
comprehensive_res = self.get_comprehensive_res(pre_std_result, model_std_result) comprehensive_res = BasePredictor.get_comprehensive_res(head_std_result, model_std_result)
if comprehensive_res: if comprehensive_res:
res = { res = {
'std_result': comprehensive_res, 'std_result': comprehensive_res,
'ab_result': pre_id_res['ab_result'], 'ab_result': head_id_res['ab_result'],
} }
return res return res
...@@ -167,9 +162,9 @@ class DicPredict(BasePredictor): ...@@ -167,9 +162,9 @@ class DicPredict(BasePredictor):
return model_id_res return model_id_res
elif predict_type == 'head': elif predict_type == 'head':
pre_id_res = self.head_predict(dic_data) head_id_res = self.head_predict(dic_data)
if pre_id_res: if head_id_res:
return pre_id_res return head_id_res
def pre_deal(self, dic_data): def pre_deal(self, dic_data):
new_dic_data = {} new_dic_data = {}
...@@ -197,26 +192,3 @@ class DicPredict(BasePredictor): ...@@ -197,26 +192,3 @@ class DicPredict(BasePredictor):
return False return False
else: else:
return True return True
def get_comprehensive_res(self, pre_std_result, model_std_result):
    """Merge the header prediction and the model prediction by voting.

    Each ``column -> standard name`` entry from either input is one vote
    for that standard name. For every standard name that received votes,
    the column chosen by ``repeat_max`` is mapped to it in the result.
    """
    print('表头预测结果', pre_std_result)
    print('模型预测结果', model_std_result)
    votes = {label: [] for label in ("类别", "参数", "型号", "数量", "品牌")}
    # Header entries are recorded before model entries.
    for prediction in (pre_std_result, model_std_result):
        for column, label in prediction.items():
            votes[label].append(column)
    merged = {}
    for label, columns in votes.items():
        if not columns:
            continue
        merged[repeat_max(columns)] = label
    return merged
...@@ -2,16 +2,16 @@ ...@@ -2,16 +2,16 @@
# -*- coding:utf-8 -*- # -*- coding:utf-8 -*-
# 可能的头部字段 # 可能的头部字段
PROB_FIELDS = ["序号", "名称", "规格", "MPN", "用量(pcs)", "用量", "pcs", "位号", "描述", "值", "数量", "封装", "类别", "a面位置", "b面位置", "备注", PROB_FIELDS = ["序号", "名称", "规格", "MPN", "用量(pcs)", "用量", "pcs", "位号", "描述", "值", "数量", "封装", "类别", "a面位置", "b面位置", "备注",
"需求数量", "售价", "封装", "封装规格", "需求数量", "售价", "封装", "封装规格", '*参数', '*型号', '*数量', '*品牌',
"参考品牌", "品牌", "item", "厂商编码", "品牌/厂商", "参考料号", "参考供应商", "top面", "bottom面"] "参考品牌", "品牌", "item", "厂商编码", "品牌/厂商", "参考料号", "参考供应商", "top面", "bottom面"]
# 标准名和代名词的映射 # 标准名和代名词的映射
STD_FIELDS_MAP = { STD_FIELDS_MAP = {
"类别": ["类别", "分类", "名称", "类别名称"], "类别": ["类别", "分类", "名称", "类别名称"],
"参数": ["参数", "规格", "描述", "值"], "参数": ["参数", "*参数", "规格", "描述", "值"],
"型号": ["型号", "参考料号", "料号", "MPN", "厂商编码"], "型号": ["型号", "*型号", "参考料号", "料号", "MPN", "厂商编码"],
"数量": ["数量", "用量(pcs)", "PCS", "用量", "用量(PCS)", "pcs"], "数量": ["数量", "*数量", "用量(pcs)", "PCS", "用量", "用量(PCS)", "pcs"],
"品牌": ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商"]} "品牌": ["品牌", "*品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商"]}
# 必须返回也必须验证的标准字段 # 必须返回也必须验证的标准字段
MUST_STD_FIELDS = ['参数', '数量'] MUST_STD_FIELDS = ['参数', '数量']
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment