Commit 1231ed2a by lichenggang

综合model和head预测, 还差位号的训练

parent 94f90604
......@@ -29,7 +29,7 @@ if __name__ == "__main__":
def get_test_data():
    """Load one sheet of a local sample workbook and return it as JSON.

    The sheet is read without a header row, empty cells are replaced by
    the ``'?'`` placeholder, and the frame is serialised column-wise as
    ``{column: [cell, ...]}``.

    :return: JSON string of the column-oriented sheet content.
    """
    import json

    import pandas as pd

    # Only sheet '3' is loaded: a preceding read of sheet '1' was overwritten
    # immediately, so it only cost a second pass over the file.
    df = pd.read_excel(r'C:\Users\ICHUNT\Desktop\bomlist\51AB0571_ CCTV ASST询价_SZIMS.xlsx', header=None, sheet_name='3')
    df = df.fillna('?')
    dic_dft = df.to_dict(orient='list')
    return json.dumps(dic_dft)
......@@ -40,5 +40,6 @@ if __name__ == "__main__":
data = get_test_data()
p = PredictorFac(model_config)
data = json.loads(data)
res = p.predict(data,predict_type='model')
print(data)
res = p.predict(data,predict_type='all')
print(res)
#!/usr/bin/env python
# -*- coding:utf-8 -*-
import re
from collections import Counter
from utils.log_manager import get_logger
from utils.robots import dd_send_msg
import pandas as pd
# Keywords searched for inside each cell by is_catecol.
CATEGORY = ['二极管']
# NOTE(review): not referenced in the visible code — purpose to be confirmed.
RIGHT_LEVEL = 0.7
# is_seq reports True when its computed rate is >= this value.
SEQ_LEVEL = 0.5
# is_catecol reports True when the share of keyword cells is >= this value.
CATE_LEVEL = 0.5
# is_multi_same reports True when the most frequent value occurs at least this many times.
MULTI_SAME_LEVEL = 3
class BasePredictor:
'''
......@@ -23,6 +25,8 @@ class BasePredictor:
self.info = get_logger(self.name + '_info')
self.robot_msg = dd_send_msg
self.pd = pd
# 占位符
self.PLACEHOLDER = '?'
def predict(self, key):
    """Prediction entry point; this base version always raises.

    :param key: input to predict on (unused here).
    :raises NotImplementedError: always.
    """
    raise NotImplementedError()
......@@ -43,7 +47,7 @@ class BasePredictor:
dic_proba = {k: v for k, v in zip(classes, deal_list)}
return predictions[0], dic_proba
def isseq(self, data):
def is_seq(self, data):
"""
序号列预测
"""
......@@ -54,7 +58,7 @@ class BasePredictor:
rate = round(len(collect_seq) / len(data), 3)
return True if rate >= SEQ_LEVEL else False
def isnum(self, data):
def is_num(self, data):
"""
数量列预测
"""
......@@ -72,16 +76,34 @@ class BasePredictor:
def isNumberCol(self, kw):
    """Helper for quantity-column detection: test a single cell value.

    A string matches when it begins with digits that are followed, up to
    the end of a line, by either ``K`` (case-insensitive) or one to three
    CJK characters.

    :param kw: cell value to test.
    :return: ``False`` for non-string input; otherwise the ``re.match``
        result (a match object, or ``None`` when the pattern fails).
    """
    if not isinstance(kw, str):
        return False
    quantity_pattern = r'(\d+)((K)|([\u4E00-\u9FA5]{1,3}))$'
    return re.match(quantity_pattern, kw, flags=re.M | re.I)
def is_catecol(self, data):
    """Category-column check.

    A cell counts once when its string form contains any keyword from
    ``CATEGORY``; previously a cell was counted once per matching keyword,
    which could push the rate above 1 with several keywords.

    :param data: cell values of one column.
    :return: ``True`` when the share of keyword cells (rounded to three
        decimals) is at least ``CATE_LEVEL``; ``False`` for an empty column.
    """
    # An empty column cannot be a category column (and would divide by zero).
    if len(data) == 0:
        return False
    hits = sum(1 for cell in data if any(word in str(cell) for word in CATEGORY))
    rate = round(hits / len(data), 3)
    return rate >= CATE_LEVEL
def is_multi_same(self, data):
    """Tell whether some non-placeholder value repeats often in a column.

    :param data: cell values of one column.
    :return: ``True`` when the most frequent value other than
        ``self.PLACEHOLDER`` occurs at least ``MULTI_SAME_LEVEL`` times;
        ``False`` when no non-placeholder value exists.
    """
    counts = Counter(cell for cell in data if cell != self.PLACEHOLDER)
    # Nothing left after dropping placeholders: no value can repeat.
    if not counts:
        return False
    return counts.most_common(1)[0][1] >= MULTI_SAME_LEVEL
def valid_seq(self, data):
    """Negation of ``is_seq``.

    :param data: cell values of one column.
    :return: ``True`` when ``is_seq`` does not recognise the column.
    """
    return not self.is_seq(data)
def valid_num(self, data):
    """Negation of ``is_num``.

    :param data: cell values of one column.
    :return: ``True`` when ``is_num`` does not recognise the column.
    """
    return not self.is_num(data)
def valid_cate(self, data):
    """Negation of ``is_catecol``.

    :param data: cell values of one column.
    :return: ``True`` when ``is_catecol`` does not recognise the column.
    """
    looks_like_category = self.is_catecol(data)
    return not looks_like_category
\ No newline at end of file
......@@ -5,19 +5,19 @@ from collections import Counter
from predict.base_handler import BasePredictor
# Cell texts that may appear in a header row
PROB_FIELDS = ["序号", "名称", "规格", "MPN", "用量(pcs)", "用量", "pcs", "位号", "描述", "值", "数量", "封装", "类别", "a面位置", "b面位置", "备注",
               "需求数量", "售价", "封装", "封装规格",
               "参考品牌", "品牌", "item", "厂商编码", "品牌/厂商", "参考料号", "参考供应商", "top面", "bottom面"]
# Mapping from each standard field name to the header texts that stand for it.
# NOTE(review): lookups against these lists lower-case the header text first,
# so the upper-case entries ("MPN", "PCS", "用量(PCS)") never match — confirm intent.
STD_FIELDS_MAP = {
    "类别": ["类别", "分类", "名称", "类别名称"],
    "参数": ["参数", "规格", "描述", "值"],
    "型号": ["型号", "参考料号", "料号", "MPN", "厂商编码"],
    "数量": ["数量", "用量(pcs)", "PCS", "用量", "用量(PCS)", "pcs"],
    "封装": ["封装", "封装规格"],
    "品牌": ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商"]}
# Standard fields that must be returned and must be validated
MUST_STD_FIELDS = ['参数', '数量']
#
......@@ -28,16 +28,19 @@ en_to_zh_map = {'brand_name': '品牌', 'param': '参数', 'goods_name': '型号
def fun(seri):
    """Find the first known header text in a pandas Series.

    Cells are compared with ``PROB_FIELDS`` case-insensitively on both
    sides; lower-casing only the cell made mixed-case entries such as
    "MPN" impossible to match.

    :param seri: Series whose cells are scanned in order.
    :return: ``(cell, seri.name)`` for the first matching cell, or ``None``
        when no cell matches.
    """
    known_headers = {str(name).lower() for name in PROB_FIELDS}
    for field in seri.tolist():
        if str(field).lower() in known_headers:
            return field, seri.name
    return None
def repeat_max(li):
    """Return the most frequent element of ``li``.

    Ties go to the element that appears first in ``li``. An empty ``li``
    raises ``IndexError``.

    :param li: iterable of hashable items,
        e.g. ['brand_name', 'goods_name', 'brand_name'].
    :return: the item with the highest occurrence count.
    """
    top_item, _count = Counter(li).most_common(1)[0]
    return top_item
# Number of leading rows to take
HEAD_ROW = 7
# Empty-cell rate threshold
NAN_RATE = 0.8
# Placeholder
PLACEHOLDER = '?'
class DicPredict(BasePredictor):
......@@ -57,14 +60,14 @@ class DicPredict(BasePredictor):
'pronoun': field,
'column_name': column_name
}
for k, v in fields_map.items():
if field.lower() in fields_map[k]:
for k, v in STD_FIELDS_MAP.items():
if field.lower() in STD_FIELDS_MAP[k]:
dic['std_name'] = k
li_res.append(dic)
return li_res
def pre_predict(self, dict_data):
def head_predict(self, dict_data):
columns = []
li_data = []
......@@ -76,15 +79,13 @@ class DicPredict(BasePredictor):
dft = df.T.head(HEAD_ROW)
dft.columns = columns
li_res_raw = self.id_by_field(dft)
std_result = []
ab_result = []
std_result = {}
ab_result = {}
for i in li_res_raw:
if i.get('std_name'):
dic_has_res = {i['column_name']: i['std_name']}
std_result.append(dic_has_res)
std_result[i['column_name']] = i['std_name']
else:
dic_ab_res = {i['column_name']: i['pronoun']}
ab_result.append(dic_ab_res)
ab_result[i['column_name']] = i['pronoun']
pre_id_res = {
'std_result': std_result,
'ab_result': ab_result,
......@@ -93,53 +94,73 @@ class DicPredict(BasePredictor):
def model_predict(self, dic_data):
    """Predict the standard field of each column from its cell values.

    Covers [sequence number (non-standard), quantity, category, parameter,
    model, brand]; the first three are decided by rule checks rather than
    by the model.

    :param dic_data: mapping ``{column: [cell, ...]}``.
    :return: ``{'std_result': {column: std_name}, 'ab_result': {column: name}}``;
        per the original note, only [parameter, quantity] are forced to
        have a result.
    """
    prob_columns = []
    temp_pre_model_res = {}
    ab_result = {}
    for k, v in dic_data.items():
        # Columns that pass v_chain are left for the model.
        if self.v_chain(v):
            prob_columns.append(k)
        elif self.is_seq(v):
            # Sequence number is not a standard field, so it is reported separately.
            ab_result[k] = '序号'
        elif self.is_num(v):
            temp_pre_model_res[k] = '数量'
        elif self.is_catecol(v):
            temp_pre_model_res[k] = '类别'

    temp_dic_data = {k: [x for x in dic_data[k] if x != self.PLACEHOLDER] for k in prob_columns}
    for k, v in temp_dic_data.items():
        # A column with no real values gives the model nothing to vote on.
        if not v:
            continue
        li_single_pred_res = [self.get_single_predict(string)[0] for string in v]
        # The column takes the label predicted most often for its cells.
        temp_pre_model_res[k] = en_to_zh_map[repeat_max(li_single_pred_res)]

    # Drop parameter / model columns in which the same value occurs many times
    prob_param_and_gn_cols = [col for col, name in temp_pre_model_res.items() if name in ('参数', '型号')]
    for col in prob_param_and_gn_cols:
        if self.is_multi_same(temp_dic_data[col]):
            temp_pre_model_res.pop(col)

    return {
        'std_result': temp_pre_model_res,
        'ab_result': ab_result,
    }
def predict(self, dic_data, predict_type='all'):
dic_data = self.pre_deal(dic_data)
if predict_type == 'all':
pre_id_res = self.pre_predict(dic_data)
if pre_id_res:
return pre_id_res
pre_id_res = self.head_predict(dic_data)
model_id_res = self.model_predict(dic_data)
# 表头预测和模型预测最后返回的数据进行综合处理
pre_std_result = pre_id_res.get('std_result')
model_std_result = model_id_res.get('std_result')
comprehensive_res = self.get_comprehensive_res(pre_std_result, model_std_result)
if comprehensive_res:
res = {
'std_result': comprehensive_res,
'ab_result': pre_id_res['ab_result'],
}
return res
elif predict_type == 'model':
model_id_res = self.model_predict(dic_data)
if model_id_res:
return model_id_res
elif predict_type == 'pre':
pre_id_res = self.pre_predict(dic_data)
elif predict_type == 'head':
pre_id_res = self.head_predict(dic_data)
if pre_id_res:
return pre_id_res
......@@ -149,7 +170,7 @@ class DicPredict(BasePredictor):
# 去掉空置率大于等于0.8的列
counter = 0
for item in v:
if str(item).strip() == PLACEHOLDER:
if str(item).strip() == self.PLACEHOLDER:
counter += 1
if counter / len(v) <= NAN_RATE:
new_dic_data[k] = v
......@@ -169,3 +190,24 @@ class DicPredict(BasePredictor):
return False
else:
return True
def get_comprehensive_res(self, pre_std_result, model_std_result):
    """Merge header-based and model-based results by voting.

    Every ``{column: std_name}`` pair from either result is one vote for
    that column under that standard name; each standard name is then
    assigned the column with the most votes (``repeat_max``).

    Standard names outside the preset list (e.g. "封装") are accepted
    instead of raising ``KeyError``.

    :param pre_std_result: ``{column: std_name}`` from the header prediction.
    :param model_std_result: ``{column: std_name}`` from the model prediction.
    :return: ``{column: std_name}``; when two standard names elect the same
        column, the one processed later overwrites the earlier one.
    """
    vote_count = {
        "类别": [],
        "参数": [],
        "型号": [],
        "数量": [],
        "品牌": []
    }
    for std_result in (pre_std_result, model_std_result):
        # A missing result contributes no votes.
        for col, std_name in (std_result or {}).items():
            vote_count.setdefault(std_name, []).append(col)

    comprehensive_res = {}
    for std_name, col_li in vote_count.items():
        if col_li:
            comprehensive_res[repeat_max(col_li)] = std_name
    return comprehensive_res
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment