diff --git a/apps/ai/src/server/trpc/route/algorithm/index.ts b/apps/ai/src/server/trpc/route/algorithm/index.ts index a0a27d56fc3..ae537b458e5 100644 --- a/apps/ai/src/server/trpc/route/algorithm/index.ts +++ b/apps/ai/src/server/trpc/route/algorithm/index.ts @@ -16,6 +16,10 @@ import { createAlgorithm, deleteAlgorithm, getAlgorithms, updateAlgorithm } from import { copyPublicAlgorithmVersion, createAlgorithmVersion, deleteAlgorithmVersion, getAlgorithmVersions, shareAlgorithmVersion, unShareAlgorithmVersion, updateAlgorithmVersion } from "./algorithmVersion"; +import { uCreateAlgorithm, uDeleteAlgorithm, uGetAlgorithms, uUpdateAlgorithm } from "./ualgorithm"; +import { uCopyPublicAlgorithmVersion, uCreateAlgorithmVersion, uDeleteAlgorithmVersion, + uGetAlgorithmVersions, uShareAlgorithmVersion, + uUnShareAlgorithmVersion, uUpdateAlgorithmVersion } from "./ualgorithmVersion"; export const algorithm = router({ getAlgorithms, @@ -29,4 +33,16 @@ export const algorithm = router({ deleteAlgorithmVersion, shareAlgorithmVersion, unShareAlgorithmVersion, + + uGetAlgorithms, + uCreateAlgorithm, + uDeleteAlgorithm, + uUpdateAlgorithm, + uCopyPublicAlgorithmVersion, + uCreateAlgorithmVersion, + uDeleteAlgorithmVersion, + uGetAlgorithmVersions, + uShareAlgorithmVersion, + uUnShareAlgorithmVersion, + uUpdateAlgorithmVersion, }); diff --git a/apps/ai/src/server/trpc/route/algorithm/ualgorithm.ts b/apps/ai/src/server/trpc/route/algorithm/ualgorithm.ts new file mode 100644 index 00000000000..019dbdb6b57 --- /dev/null +++ b/apps/ai/src/server/trpc/route/algorithm/ualgorithm.ts @@ -0,0 +1,347 @@ +/** + * Copyright (c) 2022 Peking University and Peking University Institute for Computing and Digital Economy + * SCOW is licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +import { TRPCError } from "@trpc/server"; +import { basename, dirname, join } from "path"; +import { Algorithm, Framework } from "src/server/entities/Algorithm"; +import { AlgorithmVersion, SharedStatus } from "src/server/entities/AlgorithmVersion"; +import { procedure } from "src/server/trpc/procedure/base"; +import { clusterNotFound } from "src/server/utils/errors"; +import { forkEntityManager } from "src/server/utils/getOrm"; +import { paginationProps } from "src/server/utils/orm"; +import { paginationSchema } from "src/server/utils/pagination"; +import { getArrayResponseSchema, getObjectResponseSchema } from "src/server/utils/schema"; +import { getUpdatedSharedPath, unShareFileOrDir } from "src/server/utils/share"; +import { getClusterLoginNode } from "src/server/utils/ssh"; +import { z } from "zod"; + +import { booleanQueryParam, clusterExist } from "../utils"; + +const AlgorithmSchema = z.object({ + id:z.number(), + name:z.string(), + owner:z.string(), + framework:z.nativeEnum(Framework), + isShared:z.boolean(), + description:z.string().optional(), + clusterId:z.string(), + createTime:z.string().optional(), + versions:z.array(z.string()), +}); + + +export const uGetAlgorithms = procedure + .meta({ + openapi: { + method: "GET", + path: "/v1/algorithms", + tags: ["algorithm"], + summary: "Get Algorithms", + }, + }) + .input(z.object({ + ...paginationSchema.shape, + framework: z.nativeEnum(Framework).optional(), + nameOrDesc: z.string().optional(), + clusterId: z.string().optional(), + isPublic: booleanQueryParam().optional(), + })) + .output(getArrayResponseSchema(AlgorithmSchema)) + .query(async ({ input, ctx: { user } }) => { + try { + const em = await forkEntityManager(); + const { page, pageSize, framework, nameOrDesc, clusterId, isPublic } = input; + + const [items, count] = await em.findAndCount(Algorithm, { + $and:[ + isPublic ? { isShared:true } : + { owner: user!.identityId }, + framework ? { framework } : {}, + clusterId ? { clusterId } : {}, + nameOrDesc ? + { $or: [ + { name: { $like: `%${nameOrDesc}%` } }, + { description: { $like: `%${nameOrDesc}%` } }, + ]} : {}, + ], + }, + { + ...paginationProps(page, pageSize), + populate: ["versions.sharedStatus", "versions.privatePath"], + orderBy: { createTime: "desc" }, + }); + + return { + respCode: 200, + respMessage: "Get Algorithms Success", + respBody: { + data: items.map((x) => { + return { + id:x.id, + name:x.name, + owner:x.owner, + framework:x.framework, + isShared:x.isShared, + description:x.description, + clusterId:x.clusterId, + createTime:x.createTime ? x.createTime.toISOString() : undefined, + versions: isPublic ? + x.versions.filter((x) => (x.sharedStatus === SharedStatus.SHARED)).map((y) => y.path) + : x.versions.map((y) => y.privatePath), + }; }), + total: count, + }, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Get Algorithms Fail", + respError: e?.message, + respBody: { + data: [], + total: 0, + }, + }; + } + }); + + +export const uCreateAlgorithm = procedure + .meta({ + openapi: { + method: "POST", + path: "/v1/algorithms", + tags: ["algorithm"], + summary: "create a new algorithms", + }, + }) + .input(z.object({ + name: z.string(), + framework: z.nativeEnum(Framework), + clusterId: z.string(), + description: z.string().optional(), + })) + .output(getObjectResponseSchema(z.object({ + id: z.number(), + }))) + .mutation(async ({ input, ctx: { user } }) => { + try { + if (!clusterExist(input.clusterId)) { + throw new TRPCError({ + code: "BAD_REQUEST", + message: `Cluster id ${input.clusterId} does not exist.`, + }); + } + + const em = await forkEntityManager(); + const algorithmExist = await em.findOne(Algorithm, { name:input.name, owner: user!.identityId }); + if (algorithmExist) { + throw new TRPCError({ + code: "CONFLICT", + message: `Algorithm name ${input.name} already exist`, + }); + } + + const algorithm = new Algorithm({ ...input, owner: user!.identityId }); + await em.persistAndFlush(algorithm); + return { + respCode: 200, + respMessage: "Create a New Algorithm Success", + respBody: { + data: { + id: algorithm.id, + }, + }, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Create a New Algorithm Fail", + respError: e?.message, + respBody: {}, + }; + } + + + }); + +export const uUpdateAlgorithm = procedure + .meta({ + openapi: { + method: "PUT", + path: "/v1/algorithms/{id}", + tags: ["algorithm"], + summary: "Update a Algorithm", + }, + }) + .input(z.object({ + id:z.number(), + name: z.string(), + framework: z.nativeEnum(Framework), + description: z.string().optional(), + })) + .output(getObjectResponseSchema(z.void())) + .mutation(async ({ input:{ name, framework, description, id }, ctx: { user } }) => { + + try { + const em = await forkEntityManager(); + const algorithm = await em.findOne(Algorithm, { id }); + + if (!algorithm) { + throw new TRPCError({ + code: "NOT_FOUND", + message: `Algorithm (id:${id}) is not found`, + }); + } + + const algorithmExist = await em.findOne(Algorithm, { name, + owner: user.identityId, + }); + + if (algorithmExist && algorithmExist !== algorithm) { + throw new TRPCError({ + code: "CONFLICT", + message: `Algorithm name ${name} already exist`, + }); + } + + if (algorithm.owner !== user.identityId) + throw new TRPCError({ code: "FORBIDDEN", message: `Algorithm ${id} not accessible` }); + + // 存在正在分享或正在取消分享的算法版本,则不可更新名称 + const changingVersions = await em.find(AlgorithmVersion, { algorithm, + $or: [ + { sharedStatus: SharedStatus.SHARING }, + { sharedStatus: SharedStatus.UNSHARING }, + ]}, + ); + if (changingVersions.length > 0) { + throw new TRPCError({ + code: "PRECONDITION_FAILED", + message: `Unfinished processing of algorithm ${id} exists`, + }); + } + + // 如果是已分享的算法且名称发生变化,则变更共享路径下的此算法名称为新名称 + if (algorithm.isShared && name !== algorithm.name) { + + const sharedVersions = await em.find(AlgorithmVersion, { algorithm, sharedStatus: SharedStatus.SHARED }); + const oldPath = dirname(dirname(sharedVersions[0].path)); + + // 获取更新后的当前算法的共享路径名称 + const newAlgorithmSharedPath = await getUpdatedSharedPath({ + clusterId: algorithm.clusterId, + newName: name, + oldPath, + }); + + // 更新已分享的版本的共享文件夹地址 + sharedVersions.map((v) => { + const baseFolderName = basename(v.path); + const newPath = join(newAlgorithmSharedPath, v.versionName, baseFolderName); + + v.path = newPath; + }); + } + + algorithm.framework = framework; + algorithm.name = name; + algorithm.description = description; + + await em.flush(); + return { + respCode: 200, + respMessage: "Update a Algorithm Success", + respBody: { + }, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Update a Algorithm Fail", + respError: e?.message, + respBody: {}, + }; + } + + }); + +export const uDeleteAlgorithm = procedure + .meta({ + openapi: { + method: "DELETE", + path: "/v1/algorithms/{id}", + tags: ["algorithm"], + summary: "Delete a Algorithm", + }, + }) + .input(z.object({ id: z.number() })) + .output(getObjectResponseSchema(z.void())) + .mutation(async ({ input:{ id }, ctx:{ user } }) => { + try { + const em = await forkEntityManager(); + const algorithm = await em.findOne(Algorithm, { id }); + + if (!algorithm) { + throw new TRPCError({ + code: "NOT_FOUND", + message: `Algorithm (id:${id}) is not found`, + }); + } + + if (algorithm.owner !== user.identityId) + throw new TRPCError({ code: "FORBIDDEN", message: `Algorithm (id:${id}) not accessible` }); + + const algorithmVersions = await em.find(AlgorithmVersion, { algorithm }); + + const sharingVersions = algorithmVersions.filter( + (v) => (v.sharedStatus === SharedStatus.SHARING || v.sharedStatus === SharedStatus.UNSHARING)); + + // 有正在分享中或取消分享中的版本,则不可删除 + if (sharingVersions.length > 0) { + throw new TRPCError( + { code: "PRECONDITION_FAILED", + message: `There is an algorithm version being shared or unshared of algorithm ${id}` }); + } + + const sharedVersions = algorithmVersions.filter((v) => (v.sharedStatus === SharedStatus.SHARED)); + + // 获取此算法的共享的算法绝对路径 + if (sharedVersions.length > 0) { + const sharedDatasetPath = dirname(dirname(sharedVersions[0].path)); + + const host = getClusterLoginNode(algorithm.clusterId); + if (!host) { throw clusterNotFound(algorithm.clusterId); } + + await unShareFileOrDir({ + host, + sharedPath: sharedDatasetPath, + }); + } + + await em.removeAndFlush([...algorithmVersions, algorithm]); + + return { + respCode: 200, + respMessage: "Delete a Algorithm Success", + respBody: {}, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Delete a Algorithm Fail", + respError: e?.message, + respBody: {}, + }; + } + + }); diff --git a/apps/ai/src/server/trpc/route/algorithm/ualgorithmVersion.ts b/apps/ai/src/server/trpc/route/algorithm/ualgorithmVersion.ts new file mode 100644 index 00000000000..f792c44b01f --- /dev/null +++ b/apps/ai/src/server/trpc/route/algorithm/ualgorithmVersion.ts @@ -0,0 +1,662 @@ +/** + * Copyright (c) 2022 Peking University and Peking University Institute for Computing and Digital Economy + * SCOW is licensed under Mulan PSL v2. + * You can use this software according to the terms and conditions of the Mulan PSL v2. + * You may obtain a copy of Mulan PSL v2 at: + * http://license.coscl.org.cn/MulanPSL2 + * THIS SOFTWARE IS PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, + * EITHER EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, + * MERCHANTABILITY OR FIT FOR A PARTICULAR PURPOSE. + * See the Mulan PSL v2 for more details. + */ + +import { getUserHomedir, sftpExists } from "@scow/lib-ssh"; +import { TRPCError } from "@trpc/server"; +import path, { basename, dirname, join } from "path"; +import { Algorithm } from "src/server/entities/Algorithm"; +import { AlgorithmVersion, SharedStatus } from "src/server/entities/AlgorithmVersion"; +import { procedure } from "src/server/trpc/procedure/base"; +import { checkCopyFilePath, checkCreateResourcePath } from "src/server/utils/checkPathPermission"; +import { chmod } from "src/server/utils/chmod"; +import { copyFile } from "src/server/utils/copyFile"; +import { clusterNotFound } from "src/server/utils/errors"; +import { forkEntityManager } from "src/server/utils/getOrm"; +import { logger } from "src/server/utils/logger"; +import { paginationProps } from "src/server/utils/orm"; +import { paginationSchema } from "src/server/utils/pagination"; +import { getArrayResponseSchema, getObjectResponseSchema } from "src/server/utils/schema"; +import { checkSharePermission, getUpdatedSharedPath, SHARED_TARGET, + shareFileOrDir, unShareFileOrDir } from "src/server/utils/share"; +import { getClusterLoginNode, sshConnect } from "src/server/utils/ssh"; +import { z } from "zod"; + +import { booleanQueryParam } from "../utils"; + +export const uGetAlgorithmVersions = procedure + .meta({ + openapi: { + method: "GET", + path: "/v1/algorithms/{algorithmId}/versions", + tags: ["algorithmVersion"], + summary: "Get AlgorithmVersions", + }, + }) + .input(z.object({ + ...paginationSchema.shape, + algorithmId: z.number(), + isPublic:booleanQueryParam().optional(), + })) + .output( + getArrayResponseSchema(z.object({ + id:z.number(), + versionName:z.string(), + versionDescription:z.string().optional(), + path:z.string(), + privatePath: z.string(), + sharedStatus:z.nativeEnum(SharedStatus), + createTime:z.string().optional(), + }))) + .query(async ({ input:{ algorithmId, page, pageSize, isPublic } }) => { + + try { + const em = await forkEntityManager(); + const [items, count] = await em.findAndCount(AlgorithmVersion, + { + algorithm: algorithmId, + ...isPublic ? { sharedStatus:SharedStatus.SHARED } : {}, + }, + { + populate: ["algorithm"], + ...paginationProps(page, pageSize), + orderBy: { createTime: "desc" }, + }); + + return { + respCode: 200, + respMessage: "Get AlgorithmVersions Success", + respBody: { + data: items.map((x) => { + return { + id:x.id, + versionName:x.versionName, + versionDescription:x.versionDescription, + sharedStatus:x.sharedStatus, + createTime:x.createTime ? x.createTime.toISOString() : undefined, + path:x.path, + privatePath: x.privatePath, + }; + }), + total: count, + }, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Get AlgorithmVersions Fail", + respError: e?.message, + respBody: { + data: [], + total: 0, + }, + }; + } + + }); + +export const uCreateAlgorithmVersion = procedure + .meta({ + openapi: { + method: "POST", + path: "/v1/algorithms/{algorithmId}/versions", + tags: ["algorithmVersion"], + summary: "create a new algorithmVersion", + }, + }) + .input(z.object({ + versionName: z.string(), + path: z.string(), + versionDescription: z.string().optional(), + algorithmId: z.number(), + })) + .output(getObjectResponseSchema(z.object({ id: z.number() }))) + .mutation(async ({ input, ctx: { user } }) => { + + try { + const em = await forkEntityManager(); + const algorithm = await em.findOne(Algorithm, { id: input.algorithmId }); + if (!algorithm) + throw new TRPCError({ code: "NOT_FOUND", message: `Algorithm id:${input.algorithmId} not Found` }); + + if (algorithm && algorithm.owner !== user.identityId) + throw new TRPCError({ code: "CONFLICT", + message: `Algorithm id:${input.algorithmId} is belonged to the other user`, + }); + + const algorithmVersionExist = await em.findOne(AlgorithmVersion, + { versionName: input.versionName, algorithm }); + if (algorithmVersionExist) + throw new TRPCError({ code: "CONFLICT", message: `AlgorithmVersion name:${input.versionName} already exist` }); + + // 检查目录是否存在 + const host = getClusterLoginNode(algorithm.clusterId); + + if (!host) { throw clusterNotFound(algorithm.clusterId); } + + await checkCreateResourcePath({ host, userIdentityId: user.identityId, toPath: input.path }); + + await sshConnect(host, user.identityId, logger, async (ssh) => { + const sftp = await ssh.requestSFTP(); + + if (!(await sftpExists(sftp, input.path))) { + throw new TRPCError({ code: "BAD_REQUEST", message: `${input.path} does not exists` }); + } + }); + + const algorithmVersion = new AlgorithmVersion({ ...input, privatePath: input.path, algorithm: algorithm }); + await em.persistAndFlush(algorithmVersion); + return { + respCode: 200, + respMessage: "Create a New algorithmVersion Success", + respBody: { + data: { + id: algorithmVersion.id, + }, + }, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Create a New algorithmVersion Fail", + respBody: { + }, + }; + } + }); + +export const uUpdateAlgorithmVersion = procedure + .meta({ + openapi: { + method: "PUT", + path: "/v1/algorithms/{algorithmId}/versions/{algorithmVersionId}", + tags: ["algorithmVersion"], + summary: "Update a AlgorithmVersion", + }, + }) + .input(z.object({ + algorithmId: z.number(), + algorithmVersionId: z.number(), + versionName: z.string(), + versionDescription: z.string().optional(), + })) + .output(getObjectResponseSchema(z.object({ id: z.number() }))) + .mutation(async ({ input, ctx: { user } }) => { + try { + const em = await forkEntityManager(); + + const algorithm = await em.findOne(Algorithm, { id: input.algorithmId }); + if (!algorithm) throw new TRPCError({ + code: "NOT_FOUND", message: `Algorithm id:${input.algorithmId} not found` }); + + if (algorithm.owner !== user.identityId) + throw new TRPCError({ code: "FORBIDDEN", message: `Algorithm ${input.algorithmId} not accessible` }); + + const algorithmVersion = await em.findOne(AlgorithmVersion, { id: input.algorithmVersionId }); + if (!algorithmVersion) + throw new TRPCError({ + code: "NOT_FOUND", message: `AlgorithmVersion id:${input.algorithmVersionId} not found` }); + + const algorithmVersionExist = await em.findOne(AlgorithmVersion, + { versionName: input.versionName, algorithm }); + if (algorithmVersionExist && algorithmVersionExist !== algorithmVersion) { + throw new TRPCError({ code: "CONFLICT", message: `AlgorithmVersion name:${input.versionName} already exist` }); + } + + if (algorithmVersion.sharedStatus === SharedStatus.SHARING || + algorithmVersion.sharedStatus === SharedStatus.UNSHARING) { + throw new TRPCError({ + code: "PRECONDITION_FAILED", + message: `Unfinished processing of algorithmVersion ${input.algorithmVersionId} exists`, + }); + } + + const needUpdateSharedPath = algorithmVersion.sharedStatus === SharedStatus.SHARED + && input.versionName !== algorithmVersion.versionName; + + // 更新已分享目录下的版本路径名称 + if (needUpdateSharedPath) { + // 获取更新后的已分享版本路径 + const newVersionSharedPath = await getUpdatedSharedPath({ + clusterId: algorithm.clusterId, + newName: input.versionName, + oldPath: dirname(algorithmVersion.path), + }); + + const baseFolderName = basename(algorithmVersion.path); + + algorithmVersion.path = join(newVersionSharedPath, baseFolderName); + } + + algorithmVersion.versionName = input.versionName; + algorithmVersion.versionDescription = input.versionDescription; + + await em.flush(); + return { + respCode: 200, + respMessage: "Update a AlgorithmVersion Success", + respBody: { + data: { + id: algorithmVersion.id, + }, + }, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Update a AlgorithmVersion Fail", + respError: e?.message, + respBody: {}, + }; + } + }); + +export const uDeleteAlgorithmVersion = procedure + .meta({ + openapi: { + method: "DELETE", + path: "/v1/algorithms/{algorithmId}/versions/{algorithmVersionId}", + tags: ["algorithmVersion"], + summary: "Delete a AlgorithmVersion", + }, + }) + .input(z.object({ algorithmVersionId: z.number(), algorithmId:z.number() })) + .output(getObjectResponseSchema(z.void())) + .mutation(async ({ input:{ algorithmVersionId, algorithmId }, ctx: { user } }) => { + try { + const em = await forkEntityManager(); + const algorithmVersion = await em.findOne(AlgorithmVersion, { id:algorithmVersionId }); + if (!algorithmVersion) throw new Error(`AlgorithmVersion id:${algorithmVersionId} not found`); + + const algorithm = await em.findOne(Algorithm, { id: algorithmId }, + { populate: ["versions.sharedStatus"]}); + if (!algorithm) + throw new TRPCError({ code: "NOT_FOUND", message: `Algorithm id:${algorithmId} is not found` }); + + if (algorithm.owner !== user.identityId) + throw new TRPCError({ code: "FORBIDDEN", message: `Algorithm id:${algorithmId} is not accessible` }); + + // 正在分享中或取消分享中的版本,不可删除 + if (algorithmVersion.sharedStatus === SharedStatus.SHARING + || algorithmVersion.sharedStatus === SharedStatus.UNSHARING) { + throw new TRPCError( + { code: "PRECONDITION_FAILED", + message: `AlgorithmVersion (id:${algorithmVersionId}) is currently being shared or unshared` }); + } + + // 如果是已分享的数据集版本,则删除分享 + if (algorithmVersion.sharedStatus === SharedStatus.SHARED) { + + try { + const host = getClusterLoginNode(algorithm.clusterId); + if (!host) { throw clusterNotFound(algorithm.clusterId); } + + await sshConnect(host, user.identityId, logger, async (ssh) => { + await checkSharePermission({ + ssh, + logger, + sourcePath: algorithmVersion.privatePath, + userId: user.identityId, + }); + }); + + const pathToUnshare + = algorithm.versions.filter((v) => + (v.id !== algorithmVersionId && v.sharedStatus === SharedStatus.SHARED)).length > 0 ? + // 除了此版本以外仍有其他已分享的版本则取消分享当前版本 + dirname(algorithmVersion.path) + // 除了此版本以外没有其他已分享的版本则取消分享整个算法 + : dirname(dirname(algorithmVersion.path)); + await unShareFileOrDir({ + host, + sharedPath: pathToUnshare, + }); + } catch (e) { + logger.error(`ssh failure occurred when unshare + algorithmVersion ${algorithmVersionId} of algorithm ${algorithmId}`, e); + } + + + algorithm.isShared = algorithm.versions.filter((v) => (v.sharedStatus === SharedStatus.SHARED)).length > 1 + ? true : false; + await em.flush(); + } + + em.remove(algorithmVersion); + await em.flush(); + return { + respCode: 200, + respMessage: "Delete a AlgorithmVersion Success", + respBody: {}, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Delete a AlgorithmVersion Fail", + respError: e?.message, + respBody: {}, + }; + } + + }); + +export const uShareAlgorithmVersion = procedure + .meta({ + openapi: { + method: "POST", + path: "/v1/algorithms/{algorithmId}/versions/{algorithmVersionId}/share", + tags: ["algorithmVersion"], + summary: "Share a AlgorithmVersion", + }, + }) + .input(z.object({ + algorithmId: z.number(), + algorithmVersionId: z.number(), + sourceFilePath: z.string(), + })) + .output(getObjectResponseSchema(z.void())) + .mutation(async ({ input:{ algorithmId, algorithmVersionId, sourceFilePath }, ctx: { user } }) => { + + try { + const em = await forkEntityManager(); + const algorithmVersion = await em.findOne(AlgorithmVersion, { id: algorithmVersionId }); + if (!algorithmVersion) + throw new TRPCError({ code: "NOT_FOUND", message: `AlgorithmVersion id:${algorithmId} not found` }); + + if (algorithmVersion.sharedStatus === SharedStatus.SHARED) + throw new TRPCError({ code: "CONFLICT", message: `AlgorithmVersion id:${algorithmId} is already shared` }); + + const algorithm = await em.findOne(Algorithm, { id: algorithmId }); + if (!algorithm) + throw new TRPCError({ code: "NOT_FOUND", message: `Algorithm id:${algorithmId} not found` }); + + if (algorithm.owner !== user.identityId) + throw new TRPCError({ code: "FORBIDDEN", message: `Algorithm id:${algorithmId} not accessible` }); + + const host = getClusterLoginNode(algorithm.clusterId); + if (!host) { throw clusterNotFound(algorithm.clusterId); } + + const homeTopDir = await sshConnect(host, user.identityId, logger, async (ssh) => { + // 确认是否具有分享权限 + await checkSharePermission({ ssh, logger, sourcePath: sourceFilePath, userId: user.identityId }); + // 获取分享路径的上级路径 + const userHomeDir = await getUserHomedir(ssh, user.identityId, logger); + return dirname(dirname(userHomeDir)); + }); + + algorithmVersion.sharedStatus = SharedStatus.SHARING; + em.persist([algorithmVersion]); + await em.flush(); + + const successCallback = async (targetFullPath: string) => { + const em = await forkEntityManager(); + + const algorithmVersion = await em.findOne(AlgorithmVersion, { id: algorithmVersionId }); + if (!algorithmVersion) + throw new TRPCError({ code: "NOT_FOUND", message: `AlgorithmVersion id:${algorithmId} not found` }); + + const algorithm = await em.findOne(Algorithm, { id: algorithmId }); + if (!algorithm) + throw new TRPCError({ code: "NOT_FOUND", message: `Algorithm id:${algorithmId} not found` }); + + const versionPath = join(targetFullPath, path.basename(sourceFilePath)); + algorithmVersion.sharedStatus = SharedStatus.SHARED; + algorithmVersion.path = versionPath; + if (!algorithm.isShared) { algorithm.isShared = true; }; + + await em.persistAndFlush([algorithmVersion, algorithm]); + }; + + const failureCallback = async () => { + const em = await forkEntityManager(); + + const algorithmVersion = await em.findOne(AlgorithmVersion, { id: algorithmVersionId }); + if (!algorithmVersion) + throw new TRPCError({ code: "NOT_FOUND", message: `AlgorithmVersion id:${algorithmId} not found` }); + + algorithmVersion.sharedStatus = SharedStatus.UNSHARED; + await em.persistAndFlush([algorithmVersion]); + }; + + shareFileOrDir({ + clusterId: algorithm.clusterId, + sourceFilePath, + userId: user.identityId, + sharedTarget: SHARED_TARGET.ALGORITHM, + targetName: algorithm.name, + targetSubName: algorithmVersion.versionName, + homeTopDir, + }, successCallback, failureCallback); + + return { + respCode: 200, + respMessage: "Share a AlgorithmVersion Success", + respBody: { + }, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Share a AlgorithmVersion Fail", + respError: e?.message, + respBody: {}, + }; + } + + }); + +export const uUnShareAlgorithmVersion = procedure + .meta({ + openapi: { + method: "DELETE", + path: "/v1/algorithms/{algorithmId}/versions/{algorithmVersionId}/share", + tags: ["algorithmVersion"], + summary: "Unshare a AlgorithmVersion", + }, + }) + .input(z.object({ + algorithmVersionId: z.number(), + algorithmId: z.number(), + })) + .output(getObjectResponseSchema(z.void())) + .mutation(async ({ input:{ algorithmVersionId, algorithmId }, ctx: { user } }) => { + try { + const em = await forkEntityManager(); + const algorithmVersion = await em.findOne(AlgorithmVersion, { id: algorithmVersionId }); + if (!algorithmVersion) + throw new TRPCError({ code: "NOT_FOUND", message: `AlgorithmVersion id:${algorithmVersionId} not found` }); + + if (algorithmVersion.sharedStatus === SharedStatus.UNSHARED) + throw new TRPCError({ + code: "CONFLICT", + message: `AlgorithmVersion id:${algorithmVersionId} is already unShared`, + }); + + const algorithm = await em.findOne(Algorithm, { id: algorithmId }, { + populate: ["versions.sharedStatus"], + }); + if (!algorithm) + throw new TRPCError({ code: "NOT_FOUND", message: `Algorithm id:${algorithmId} not found` }); + + if (algorithm.owner !== user.identityId) + throw new TRPCError({ code: "FORBIDDEN", message: `Algorithm id:${algorithmId} not accessible` }); + + const host = getClusterLoginNode(algorithm.clusterId); + if (!host) { throw clusterNotFound(algorithm.clusterId); } + + await sshConnect(host, user.identityId, logger, async (ssh) => { + await checkSharePermission({ + ssh, + logger, + sourcePath: algorithmVersion.privatePath, + userId: user.identityId, + }); + }); + + algorithmVersion.sharedStatus = SharedStatus.UNSHARING; + em.persist([algorithmVersion]); + await em.flush(); + + const successCallback = async () => { + const em = await forkEntityManager(); + + const algorithmVersion = await em.findOne(AlgorithmVersion, { id: algorithmVersionId }); + if (!algorithmVersion) + throw new TRPCError({ code: "NOT_FOUND", message: `AlgorithmVersion id:${algorithmId} not found` }); + + const algorithm = await em.findOne(Algorithm, { id: algorithmId }, { + populate: ["versions.sharedStatus"], + }); + if (!algorithm) + throw new TRPCError({ code: "NOT_FOUND", message: `Algorithm id:${algorithmId} not found` }); + + algorithmVersion.sharedStatus = SharedStatus.UNSHARED; + algorithmVersion.path = algorithmVersion.privatePath; + algorithm.isShared = algorithm.versions.filter((v) => (v.sharedStatus === SharedStatus.SHARED)).length > 0 + ? true : false; + + await em.persistAndFlush([algorithmVersion, algorithm]); + }; + + const failureCallback = async () => { + const em = await forkEntityManager(); + + const algorithmVersion = await em.findOne(AlgorithmVersion, { id: algorithmVersionId }); + if (!algorithmVersion) + throw new TRPCError({ code: "NOT_FOUND", message: `AlgorithmVersion id:${algorithmId} not found` }); + + algorithmVersion.sharedStatus = SharedStatus.SHARED; + await em.persistAndFlush([algorithmVersion]); + }; + + unShareFileOrDir({ + host, + sharedPath: algorithm.versions.filter((v) => (v.sharedStatus === SharedStatus.SHARED)).length > 0 ? + // 如果还有其他的已分享版本则只取消此版本的分享 + dirname(algorithmVersion.path) + // 如果没有其他的已分享版本则取消整个算法的分享 + : dirname(dirname(algorithmVersion.path)), + }, successCallback, failureCallback); + + return { + respCode: 200, + respMessage: "Unshare a AlgorithmVersion Success", + respBody: { + }, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Unshare a AlgorithmVersion Fail", + respError: e?.message, + respBody: {}, + }; + } + + }); + +export const uCopyPublicAlgorithmVersion = procedure + .meta({ + openapi: { + method: "POST", + path: "/v1/algorithms/{algorithmId}/versions/{algorithmVersionId}/copy", + tags: ["algorithmVersion"], + summary: "Copy a Public Algorithm Version", + }, + }) + .input(z.object({ + algorithmId: z.number(), + algorithmVersionId: z.number(), + algorithmName: z.string(), + versionName: z.string(), + versionDescription: z.string(), + path: z.string(), + })) + .output(getObjectResponseSchema(z.object({ success: z.boolean() }))) + .mutation(async ({ input, ctx: { user } }) => { + try { + const em = await forkEntityManager(); + + // 1. 检查算法版本是否为公开版本 + const algorithmVersion = await em.findOne(AlgorithmVersion, + { id: input.algorithmVersionId, sharedStatus: SharedStatus.SHARED }, + { populate: ["algorithm"]}); + + if (!algorithmVersion) { + throw new TRPCError({ + code: "NOT_FOUND", + message: `Algorithm Version ${input.algorithmVersionId} does not exist or is not public`, + }); + } + // 2. 检查该用户是否已有同名算法 + const algorithm = await em.findOne(Algorithm, { name: input.algorithmName, owner: user.identityId }); + if (algorithm) { + throw new TRPCError({ + code: "CONFLICT", + message: `An algorithm with the same name as ${input.algorithmName} already exists`, + }); + } + + // 3. 检查用户是否可以将源算法拷贝至目标目录 + const host = getClusterLoginNode(algorithmVersion.algorithm.$.clusterId); + + if (!host) { throw clusterNotFound(algorithmVersion.algorithm.$.clusterId); } + + await checkCopyFilePath({ host, userIdentityId: user.identityId, + toPath: input.path, fileName: path.basename(algorithmVersion.path) }); + + // 4. 写入数据 + const newAlgorithm = new Algorithm({ + name: input.algorithmName, + owner: user.identityId, + framework: algorithmVersion.algorithm.$.framework, + description: algorithmVersion.algorithm.$.description, + clusterId: algorithmVersion.algorithm.$.clusterId, + }); + + const newAlgorithmVersion = new AlgorithmVersion({ + versionName: input.versionName, + versionDescription: input.versionDescription, + path: input.path, + privatePath: input.path, + algorithm: newAlgorithm, + }); + + try { + await copyFile({ host, userIdentityId: user.identityId, + fromPath: algorithmVersion.path, toPath: input.path }); + // 递归修改文件权限和拥有者 + await chmod({ host, userIdentityId: "root", permission: "750", path: input.path }); + await em.persistAndFlush([newAlgorithm, newAlgorithmVersion]); + } catch (err) { + console.log(err); + throw new TRPCError({ + code: "INTERNAL_SERVER_ERROR", + message: `Copy Error ${err}`, + }); + } + return { + respCode: 200, + respMessage: "Copy a Public Algorithm Version Success", + respBody: { + data: { success: true }, + }, + }; + } catch (e: any) { + return { + respCode: 400, + respMessage: "Copy a Public Algorithm Version Fail", + respError: e?.message, + respBody: {}, + }; + } + + });