Skip to content

Commit

Permalink
feat: DIA-1715: VertexAI Gemini model support (#6865)
Browse files Browse the repository at this point in the history
Co-authored-by: hakan458 <[email protected]>
Co-authored-by: Matt Bernstein <[email protected]>
  • Loading branch information
3 people authored Jan 17, 2025
1 parent 4aba7ad commit ddd5ead
Show file tree
Hide file tree
Showing 3 changed files with 106 additions and 0 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
# Generated by Django 5.1.4 on 2025-01-03 20:58

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
(
"ml_model_providers",
"0005_modelproviderconnection_budget_alert_threshold_and_more",
),
]

operations = [
migrations.AddField(
model_name="modelproviderconnection",
name="google_application_credentials",
field=models.TextField(
blank=True,
help_text="The content of GOOGLE_APPLICATION_CREDENTIALS json file",
null=True,
verbose_name="google application credentials",
),
),
migrations.AddField(
model_name="modelproviderconnection",
name="google_location",
field=models.CharField(
blank=True,
help_text="Google project location",
max_length=255,
null=True,
verbose_name="google location",
),
),
migrations.AddField(
model_name="modelproviderconnection",
name="google_project_id",
field=models.CharField(
blank=True,
help_text="Google project ID",
max_length=255,
null=True,
verbose_name="google project id",
),
),
migrations.AlterField(
model_name="modelproviderconnection",
name="provider",
field=models.CharField(
choices=[
("OpenAI", "OpenAI"),
("AzureOpenAI", "AzureOpenAI"),
("VertexAI", "VertexAI"),
("Custom", "Custom"),
],
default="OpenAI",
max_length=255,
),
),
]
16 changes: 16 additions & 0 deletions label_studio/ml_model_providers/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
class ModelProviders(models.TextChoices):
OPENAI = 'OpenAI', _('OpenAI')
AZURE_OPENAI = 'AzureOpenAI', _('AzureOpenAI')
VERTEX_AI = 'VertexAI', _('VertexAI')
CUSTOM = 'Custom', _('Custom')


Expand All @@ -32,6 +33,21 @@ class ModelProviderConnection(models.Model):

endpoint = models.CharField(max_length=512, null=True, blank=True, help_text='Azure OpenAI endpoint')

google_application_credentials = models.TextField(
_('google application credentials'),
null=True,
blank=True,
help_text='The content of GOOGLE_APPLICATION_CREDENTIALS json file',
)

google_project_id = models.CharField(
_('google project id'), max_length=255, null=True, blank=True, help_text='Google project ID'
)

google_location = models.CharField(
_('google location'), max_length=255, null=True, blank=True, help_text='Google project location'
)

cached_available_models = models.CharField(
max_length=4096, null=True, blank=True, help_text='List of available models from the provider'
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
# Generated by Django 5.1.4 on 2025-01-03 20:58

from django.db import migrations, models


class Migration(migrations.Migration):

dependencies = [
("ml_models", "0012_alter_thirdpartymodelversion_provider"),
]

operations = [
migrations.AlterField(
model_name="thirdpartymodelversion",
name="provider",
field=models.CharField(
choices=[
("OpenAI", "OpenAI"),
("AzureOpenAI", "AzureOpenAI"),
("VertexAI", "VertexAI"),
("Custom", "Custom"),
],
default="OpenAI",
help_text="The model provider to use e.g. OpenAI",
max_length=255,
),
),
]

0 comments on commit ddd5ead

Please sign in to comment.