Commit 7756c9be by lichenggang

增加有效数据行表头预测

parent ed3f93c2
...@@ -17,22 +17,15 @@ class Classify(classify_pb2_grpc.classifyServicer): ...@@ -17,22 +17,15 @@ class Classify(classify_pb2_grpc.classifyServicer):
def Classify(self, request, context): def Classify(self, request, context):
bom_log.info(f'grpc收到数据: {request.keyword}') bom_log.info(f'grpc收到数据: {request.keyword}')
err_msg = None
try: try:
dic_data = json.loads(request.keyword) dic_data = json.loads(request.keyword)
res = self.predictorfac.predict(dic_data) res = self.predictorfac.predict(dic_data)
res['status'] = 1 res['status'] = 1
if '数量' not in str(res): if '数量' not in str(res):
res = { err_msg = '未识别到数量列'
'status': 0,
'err_msg': f'未识别到数量列'
}
if '参数' not in str(res) and '型号' not in str(res): if '参数' not in str(res) and '型号' not in str(res):
res = { err_msg = '未识别到参数列'
'status': 0,
'err_msg': f'未识别到参数列'
}
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
res = self.predictorfac.predict(request.keyword) res = self.predictorfac.predict(request.keyword)
res['status'] = 1 res['status'] = 1
...@@ -42,6 +35,12 @@ class Classify(classify_pb2_grpc.classifyServicer): ...@@ -42,6 +35,12 @@ class Classify(classify_pb2_grpc.classifyServicer):
'err_msg': '未知' 'err_msg': '未知'
} }
bom_log.error(f'grpc报错: \n' + traceback.format_exc()) bom_log.error(f'grpc报错: \n' + traceback.format_exc())
if err_msg:
res = {
'status': 0,
'err_msg': err_msg
}
return classify_pb2.ClassifyReply(message='result {msg}'.format(msg=res)) return classify_pb2.ClassifyReply(message='result {msg}'.format(msg=res))
def fac_test_predic(self, data): def fac_test_predic(self, data):
......
...@@ -35,22 +35,17 @@ class KwHandler(BaseHandler): ...@@ -35,22 +35,17 @@ class KwHandler(BaseHandler):
class DictHandler(BaseHandler): class DictHandler(BaseHandler):
async def post(self): async def post(self):
err_msg = None
try: try:
bom_log.info(f'http收到识别字典: {self.request.body}') bom_log.info(f'http收到识别字典: {self.request.body}')
dic_data = json.loads(self.request.body) dic_data = json.loads(self.request.body)
res = self.predictorfac.predict(dic_data) res = self.predictorfac.predict(dic_data)
res['status'] = 1 res['status'] = 1
if '数量' not in str(res): if '数量' not in str(res):
res = { err_msg = '未识别到数量列'
'status': 0,
'err_msg': '未识别到数量列'
}
if '参数' not in str(res) and '型号' not in str(res): if '参数' not in str(res) and '型号' not in str(res):
res = { err_msg = '未识别到参数列'
'status': 0,
'err_msg': '未识别到参数列'
}
except json.decoder.JSONDecodeError: except json.decoder.JSONDecodeError:
res = self.predictorfac.predict(self.request.body) res = self.predictorfac.predict(self.request.body)
res['status'] = 1 res['status'] = 1
...@@ -60,6 +55,12 @@ class DictHandler(BaseHandler): ...@@ -60,6 +55,12 @@ class DictHandler(BaseHandler):
'err_msg': '未知' 'err_msg': '未知'
} }
bom_log.error(f'http字典识别报错: \n' + traceback.format_exc()) bom_log.error(f'http字典识别报错: \n' + traceback.format_exc())
if err_msg:
res = {
'status': 0,
'err_msg': err_msg
}
self.write(res) self.write(res)
......
...@@ -147,22 +147,24 @@ class BasePredictor: ...@@ -147,22 +147,24 @@ class BasePredictor:
return li_sort[0][0] return li_sort[0][0]
@classmethod @classmethod
def get_comprehensive_res(cls, head_std_result, model_std_result): def get_comprehensive_res(cls, head_result, model_result, major='head', vote_type='standard'):
if vote_type == 'standard':
vote_count = { vote_count = {k: []for k in STD_FIELDS_MAP.keys()}
"类别": [], else:
"参数": [], set_k = set()
"型号": [], for v in head_result.values():
"数量": [], set_k.add(v)
"品牌": [] for v in model_result.values():
} set_k.add(v)
for k, v in head_std_result.items(): vote_count = {k: [] for k in set_k}
for k, v in head_result.items():
vote_count[v].append(k) vote_count[v].append(k)
for k, v in model_std_result.items(): for k, v in model_result.items():
vote_count[v].append(k) vote_count[v].append(k)
# 此处并没有处理表头和模型两者预测冲突的情况, 但因为repeat_max方法是稳定的,所以冲突后会取第一个也就是表头预测的结果 # 此处并没有处理表头和模型两者预测冲突的情况, 但因为repeat_max方法是稳定的,所以冲突后会取第一个也就是表头预测的结果, 之后按照major参数决定冲突结果
comprehensive_res = {} comprehensive_res = {}
for std_name, col_li in vote_count.items(): for std_name, col_li in vote_count.items():
if col_li: if col_li:
......
...@@ -6,11 +6,20 @@ from static_config import * ...@@ -6,11 +6,20 @@ from static_config import *
def fun(seri):
    """Find the first cell of a column that looks like a known header name.

    Args:
        seri: Column to scan. It must provide ``tolist()`` and ``name``
            (presumably a pandas Series -- confirm against the caller).

    Returns:
        ``(field, seri.name)`` for the first value whose lower-cased string
        form is in ``PROB_FIELDS``, or ``None`` when no value matches.
    """
    # The position of the cell is not needed here, so iterate the values
    # directly instead of enumerating them.
    for field in seri.tolist():
        if str(field).lower() in PROB_FIELDS:
            return field, seri.name
    # No header-like cell in this column.
    return None
def get_head_row(li: list) -> int:
    """Return the row at which valid data starts in the column ``li``.

    The column is scanned for the first cell whose lower-cased string form
    is in ``PROB_FIELDS``; the result is that cell's index plus one.

    Args:
        li: Cell values of one column, in row order.

    Returns:
        Index of the first matching cell plus one, or ``0`` when no cell
        matches. ``0`` is falsy, so callers that only truth-test the result
        treat "not found" as before, while the declared ``int`` return type
        is honoured.
    """
    for position, cell in enumerate(li, start=1):
        if str(cell).lower() in PROB_FIELDS:
            return position
    # No header-like cell found in this column.
    return 0
# 取前多少行 # 取前多少行
HEAD_ROW = 7 HEAD_ROW = 7
# 空置率阈值 # 空置率阈值
...@@ -29,6 +38,7 @@ class DicPredict(BasePredictor): ...@@ -29,6 +38,7 @@ class DicPredict(BasePredictor):
li_fie = series.tolist() li_fie = series.tolist()
li_res = [] li_res = []
for field, column_name in li_fie: for field, column_name in li_fie:
print(field)
dic = { dic = {
'std_name': '', 'std_name': '',
'pronoun': field, 'pronoun': field,
...@@ -44,9 +54,19 @@ class DicPredict(BasePredictor): ...@@ -44,9 +54,19 @@ class DicPredict(BasePredictor):
columns = [] columns = []
li_data = [] li_data = []
dict_head_rows = {}
effect_row = 0
for k, v in dict_data.items(): for k, v in dict_data.items():
columns.append(k) columns.append(k)
li_data.append(v) li_data.append(v)
temp_effect_row = get_head_row(v)
if temp_effect_row:
if dict_head_rows.get(temp_effect_row):
dict_head_rows[temp_effect_row] += 1
else:
dict_head_rows[temp_effect_row] = 0
if dict_head_rows:
effect_row = max(dict_head_rows, key=dict_head_rows.get)
df = self.pd.DataFrame(li_data) df = self.pd.DataFrame(li_data)
dft = df.T.head(HEAD_ROW) dft = df.T.head(HEAD_ROW)
...@@ -75,6 +95,7 @@ class DicPredict(BasePredictor): ...@@ -75,6 +95,7 @@ class DicPredict(BasePredictor):
head_id_res = { head_id_res = {
'std_result': std_result, 'std_result': std_result,
'ab_result': ab_result, 'ab_result': ab_result,
'effect_row': effect_row
} }
return head_id_res return head_id_res
...@@ -88,7 +109,10 @@ class DicPredict(BasePredictor): ...@@ -88,7 +109,10 @@ class DicPredict(BasePredictor):
prob_columns = [] prob_columns = []
temp_pre_model_res = {} temp_pre_model_res = {}
ab_result = {} ab_result = {}
# set_head_rows = set()
for k, v in dic_data.items(): for k, v in dic_data.items():
no_null_v = list(filter(lambda x: x != BasePredictor.PLACEHOLDER, v)) no_null_v = list(filter(lambda x: x != BasePredictor.PLACEHOLDER, v))
bol = self.v_chain(v) bol = self.v_chain(v)
if self.is_seq(no_null_v): if self.is_seq(no_null_v):
...@@ -128,7 +152,8 @@ class DicPredict(BasePredictor): ...@@ -128,7 +152,8 @@ class DicPredict(BasePredictor):
temp_pre_model_res[k] = EN_TO_ZH_MAP[result] temp_pre_model_res[k] = EN_TO_ZH_MAP[result]
# 若有多个型号列,则进行不同率的比较, 不同率最高的选为目标列 # 若有多个型号列,则进行不同率的比较, 不同率最高的选为目标列
not_null_dic_data = {k: list(filter(lambda x: x != BasePredictor.PLACEHOLDER, dic_data[k])) for k in prob_columns} not_null_dic_data = {k: list(filter(lambda x: x != BasePredictor.PLACEHOLDER, dic_data[k])) for k in
prob_columns}
# prob_param_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '参数'] # prob_param_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '参数']
prob_gn_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '型号'] prob_gn_cols = [i for i in temp_pre_model_res if temp_pre_model_res[i] == '型号']
for param_or_gn_col_list in [prob_gn_cols]: for param_or_gn_col_list in [prob_gn_cols]:
...@@ -189,6 +214,7 @@ class DicPredict(BasePredictor): ...@@ -189,6 +214,7 @@ class DicPredict(BasePredictor):
model_id_res = { model_id_res = {
'std_result': temp_pre_model_res, 'std_result': temp_pre_model_res,
'ab_result': ab_result, 'ab_result': ab_result,
'effect_row': 0 #TODO 模型暂时不预测表头
} }
return model_id_res return model_id_res
...@@ -202,16 +228,24 @@ class DicPredict(BasePredictor): ...@@ -202,16 +228,24 @@ class DicPredict(BasePredictor):
self.info.info('表头预测结果: ' + str(head_id_res)) self.info.info('表头预测结果: ' + str(head_id_res))
self.info.info('模型预测结果: ' + str(model_id_res)) self.info.info('模型预测结果: ' + str(model_id_res))
# 表头预测和模型预测最后返回的数据进行综合处理 # 表头预测和模型预测最后返回的标准字段数据进行综合处理
head_std_result = head_id_res.get('std_result') head_std_result = head_id_res.get('std_result')
model_std_result = model_id_res.get('std_result') model_std_result = model_id_res.get('std_result')
comprehensive_res = BasePredictor.get_comprehensive_res(head_std_result, model_std_result) std_comprehensive_res = BasePredictor.get_comprehensive_res(head_std_result, model_std_result)
if comprehensive_res:
# 表头预测和模型预测最后返回的非标准字段数据进行综合处理
head_ab_result = head_id_res.get('ab_result')
model_ab_result = model_id_res.get('ab_result')
ab_comprehensive_res = BasePredictor.get_comprehensive_res(head_ab_result, model_ab_result,
vote_type='abnomal')
if std_comprehensive_res:
res = { res = {
'std_result': comprehensive_res, 'std_result': std_comprehensive_res,
'ab_result': head_id_res['ab_result'] or model_id_res['ab_result'] or {}, 'ab_result': ab_comprehensive_res,
'effect_row': head_id_res.get('effect_row')
} }
self.info.info('综合预测结果: ' + str(comprehensive_res)) self.info.info('综合预测结果: ' + str(res))
return res return res
elif predict_type == 'model': elif predict_type == 'model':
......
...@@ -3,19 +3,32 @@ ...@@ -3,19 +3,32 @@
# 可能的头部字段 # 可能的头部字段
PROB_FIELDS = [
    "序号", "名称", "规格", "mpn", "用量(pcs)", "用量", "pcs", "位号", "描述", "值", "数量", "封装", "类别", "a面位置", "b面位置", "备注",
    "需求数量", "售价", "封装", "封装规格", '型号', '参数', '数量', '品牌', '型号', '类型', "quantity", "qty", "buy qty",
    # NOTE: the trailing comma after "description" is required. Without it
    # Python joins the adjacent string literals into "description参考品牌",
    # and neither "description" nor "参考品牌" ends up in the list.
    "buy quantity", "需求用量", "manufacturer制造商", "manufacturer", "description",
    "参考品牌", "品牌", "item", "厂商编码", "品牌/厂商", "参考料号", "参考供应商", "top面", "bottom面", "designator", "remark", "元器件",
    "标号",
]
# Every header name is also accepted with a leading '*'. The starred copies
# are built from the un-starred list first and only then appended to it.
PROB_FIELDS_1 = ['*' + i for i in PROB_FIELDS]
PROB_FIELDS.extend(PROB_FIELDS_1)
# 标准名和代名词的映射 # 标准名和代名词的映射
# Header aliases for each standard field (un-starred forms).
li_category = ["类别", "分类", "名称", "类别名称", "类型"]
li_param = ["参数", "规格", "描述", "值", "description"]
li_gn = ["型号", "参考料号", "料号", "mpn", "厂商编码", "元器件"]
li_num = ["数量", "用量(pcs)", "用量", "pcs", "quantity", "qty", "buy qty", "buy quantity", "需求用量"]
li_brand = ["品牌", "品牌/厂商", "参考品牌", "参考供应商", "厂商", "参考供应商", "参考厂商", "manufacturer制造商", "manufacturer"]

# Append a '*'-prefixed copy of every alias to its own list, in place.
# The starred copies are fully built before the list is extended.
for _alias_list in (li_category, li_param, li_gn, li_num, li_brand):
    _alias_list += ['*' + _alias for _alias in _alias_list]
del _alias_list
# Standard field name -> list of header aliases recognised for that field.
# Keys are paired with the alias lists in the same order as before.
STD_FIELDS_MAP = dict(
    zip(
        ("类别", "参数", "型号", "数量", "品牌"),
        (li_category, li_param, li_gn, li_num, li_brand),
    )
)
# 必须返回也必须验证的标准字段 # 必须返回也必须验证的标准字段
MUST_STD_FIELDS = ['参数', '数量'] MUST_STD_FIELDS = ['参数', '数量']
...@@ -46,4 +59,4 @@ DIC_PRIORITY = { ...@@ -46,4 +59,4 @@ DIC_PRIORITY = {
'数量': 7, '数量': 7,
'序号': 6, '序号': 6,
'其他': 5 '其他': 5
} }
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment